# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for keras functional model."""

import collections

from keras import backend
from keras.engine import input_layer as input_layer_module
from keras.engine import keras_tensor
from keras.engine import node as node_module

import tensorflow.compat.v2 as tf

# Error message template used by the helpers in this module when an object
# that is not a KerasTensor shows up among the tensors they are given. The
# `{}` placeholder is filled with the offending object via `.format(...)`.
_KERAS_TENSOR_TYPE_CHECK_ERROR_MSG = (
    'Found unexpected instance while processing input tensors for keras '
    'functional model. Expecting KerasTensor which is from tf.keras.Input() '
    'or output from keras layer call(). Got: {}')


def is_input_keras_tensor(tensor):
  """Tell whether `tensor` was produced directly by `tf.keras.Input`.

  When a functional model is built from tensors that are not `tf.keras.Input`
  outputs, the Nodes and KerasTensors in between have to be cloned, so callers
  use this predicate to decide whether that cloning is required.

  Args:
    tensor: The `KerasTensor` being considered as a functional model input.

  Returns:
    bool. The `is_input` flag of the Node that produced `tensor`.

  Raises:
    ValueError: if `tensor` is not recognized as a KerasTensor.
  """
  if node_module.is_keras_tensor(tensor):
    # The producing node knows whether it belongs to an Input layer.
    return tensor.node.is_input
  raise ValueError(_KERAS_TENSOR_TYPE_CHECK_ERROR_MSG.format(tensor))


def find_nodes_by_inputs_and_outputs(inputs, outputs):
  """Fetch all Nodes in the graph defined by "inputs" and "outputs".

  This method is used to find and then clone Nodes when creating a new
  sub-model from an existing functional model.

  Args:
    inputs: A nested structure of KerasTensor to use as model inputs.
    outputs: A nested structure of KerasTensor to use as model outputs.

  Returns:
    A list of Nodes that are connected to the inputs and outputs, in the
    order they were first reached by a breadth-first walk that starts at the
    nodes producing `outputs` and moves upstream.

  Raises:
    ValueError: when inputs and outputs are disconnected or in case of
      unexpected objects in the inputs/outputs.
  """
  # We walk the graph bottom up, starting from output nodes, and keep tracing
  # the upstream node, until we find all the inputs nodes. We don't use top
  # down search here since we don't know whether a certain node is in the graph
  # between inputs and outputs, e.g. a functional graph could have multiple
  # outputs, and the user could choose a subset of them to build the model.
  # The bottom up approach will ensure all the nodes we visit are actually
  # in use. If we reach the top and didn't find the nodes in the `inputs`,
  # that's an error, since the user didn't specify the correct inputs.
  start_keras_tensors = tf.nest.flatten(outputs)
  end_keras_tensors = tf.nest.flatten(inputs)

  # Reject anything that is not a KerasTensor before touching `.node`.
  for t in start_keras_tensors + end_keras_tensors:
    if not node_module.is_keras_tensor(t):
      raise ValueError(_KERAS_TENSOR_TYPE_CHECK_ERROR_MSG.format(t))

  # Tensors are tracked by object identity rather than by equality.
  end_ids = {id(kt) for kt in end_keras_tensors}
  # Track all the end tensors we found so far, if we didn't reach all the
  # user-specified keras inputs after we finish the search, then that's an
  # error since the inputs are disconnected from the outputs.
  end_ids_found = set()

  # FIFO work queue. A deque gives O(1) removal from the left, whereas
  # `list.pop(0)` shifts every remaining element on each pop.
  nodes_to_visit = collections.deque(t.node for t in start_keras_tensors)
  nodes_in_graph = []
  node_id_visited = set()

  while nodes_to_visit:
    node = nodes_to_visit.popleft()
    if id(node) in node_id_visited:
      # A node can be queued several times when it feeds multiple consumers.
      continue
    node_id_visited.add(id(node))
    nodes_in_graph.append(node)
    # Any input keras_tensor that produce the current node.
    for kt in node.keras_inputs:
      if id(kt) in end_ids:
        # We found the inputs of the model, stop tracing upstream nodes
        end_ids_found.add(id(kt))
        continue

      inbound_node = kt.node
      # Reaching a tf.keras.Input node here means `kt` is a graph input that
      # was not listed in the user-specified `inputs`, so the requested
      # inputs do not cover everything the outputs depend on.
      if inbound_node.is_input:
        raise ValueError('Found input tensor cannot be reached given provided '
                         'output tensors. Please make sure the tensor {} is '
                         'included in the model inputs when building '
                         'functional model.'.format(kt))
      nodes_to_visit.append(inbound_node)

  # Do a final check and make sure we have reached all the user-specified inputs
  if end_ids != end_ids_found:
    unvisited_inputs = [kt for kt in end_keras_tensors
                        if id(kt) not in end_ids_found]
    raise ValueError('Found unvisited input tensors that are disconnected from '
                     'the outputs: {}'.format(unvisited_inputs))
  return nodes_in_graph


def clone_graph_nodes(inputs, outputs):
  """Duplicate every `Node` lying between `inputs` and `outputs`.

  Lets a new functional model be built from arbitrary intermediate keras
  tensors. Re-creating the nodes imitates what would happen if every layer
  `__call__` between the inputs and outputs were executed again; the new nodes
  are attached to the existing layers by the `Node` constructor.

  Entries of `inputs` that already come from `tf.keras.Input` are reused as
  they are; every other entry is replaced by a newly created `tf.keras.Input`.

  Args:
    inputs: A nested structure of keras_tensors.
    outputs: A nested structure of keras_tensors.

  Returns:
    A `(cloned_inputs, cloned_outputs)` pair with the same nesting as the
    arguments, suitable for constructing a new functional model.
  """
  nodes_to_clone = find_nodes_by_inputs_and_outputs(inputs, outputs)

  # Maps id(original keras tensor) -> the keras tensor standing in for it in
  # the cloned graph. Tensors are cloned, not shared, so that the
  # _keras_history of the originals is never overwritten.
  kt_id_mapping = {}

  flat_new_inputs = []
  for original in tf.nest.flatten(inputs):
    if original.node.is_input:
      # Already a tf.keras.Input tensor: keep it untouched.
      replacement = original
    else:
      # Intermediate tensor: wrap a copy of it in a brand new tf.keras.Input.
      replacement = input_layer_module.Input(
          tensor=_clone_keras_tensor(original))
    flat_new_inputs.append(replacement)
    kt_id_mapping[id(original)] = replacement
  cloned_inputs = tf.nest.pack_sequence_as(inputs, flat_new_inputs)

  flat_new_outputs = []
  for original in tf.nest.flatten(outputs):
    duplicate = _clone_keras_tensor(original)
    # Carry over the old _keras_history so the Node constructor recognizes the
    # copy as a keras tensor; the constructor then replaces that history for
    # the outputs of the layer it wires up.
    duplicate._keras_history = original._keras_history  # pylint: disable=protected-access
    flat_new_outputs.append(duplicate)
    kt_id_mapping[id(original)] = duplicate
  cloned_outputs = tf.nest.pack_sequence_as(outputs, flat_new_outputs)

  for source_node in nodes_to_clone:
    # Each keras tensor is either looked up in the mapping (already cloned) or
    # cloned now and recorded there.
    new_outputs = clone_keras_tensors(source_node.output_tensors, kt_id_mapping)
    new_call_args = clone_keras_tensors(source_node.call_args, kt_id_mapping)
    new_call_kwargs = clone_keras_tensors(
        source_node.call_kwargs, kt_id_mapping)
    # Building the Node is done purely for its side effects: it registers
    # itself in the layer's inbound nodes, stamps _keras_history on the
    # outputs, and registers itself as an outbound node of the layers that
    # produced its inputs.
    node_module.Node(
        source_node.layer,
        call_args=new_call_args,
        call_kwargs=new_call_kwargs,
        outputs=new_outputs)
  return cloned_inputs, cloned_outputs


def clone_keras_tensors(args, keras_tensor_mapping):
  """Return `args` with every KerasTensor swapped for its clone.

  Each KerasTensor found in `args` is replaced by the copy recorded in
  `keras_tensor_mapping`; when no copy exists yet, one is created and added to
  the mapping. Objects that are not KerasTensors pass through unchanged. This
  supports cloning Nodes, since a KerasTensor can't be reused across models.

  Args:
    args: A nested structure of objects, which could contain KerasTensor.
    keras_tensor_mapping: A dict keyed by the `id()` of an original
      KerasTensor, with the cloned KerasTensor as value. It is updated in place
      with any clones created during this call.

  Returns:
    A structure nested like `args`, with KerasTensors replaced by clones.
  """

  def _clone_or_reuse(item):
    """Map one leaf of `args` to its counterpart in the result."""
    if not node_module.is_keras_tensor(item):
      return item
    key = id(item)
    if key not in keras_tensor_mapping:
      # First time this tensor is seen: copy it and remember the copy.
      fresh = _clone_keras_tensor(item)
      fresh._keras_history = item._keras_history  # pylint: disable=protected-access
      keras_tensor_mapping[key] = fresh
    return keras_tensor_mapping[key]

  flat_result = [_clone_or_reuse(item) for item in tf.nest.flatten(args)]
  return tf.nest.pack_sequence_as(args, flat_result)


def _clone_keras_tensor(kt):
  """Build a new KerasTensor equivalent to `kt`.

  The copy is made by converting `kt` to a placeholder with
  `keras_tensor_to_placeholder` and converting that placeholder back with
  `keras_tensor_from_tensor`; the intent of this round trip is to keep any
  inferred shape information on the copy.

  Args:
    kt: the input KerasTensor.

  Returns:
    A new KerasTensor created from a placeholder derived from `kt`.
  """
  # The placeholder is only a stepping stone, so it is created inside a
  # scratch graph instead of whatever graph is currently active.
  with backend._scratch_graph() as scratch, scratch.as_default():  # pylint: disable=protected-access
    return keras_tensor.keras_tensor_from_tensor(
        keras_tensor.keras_tensor_to_placeholder(kt))
