# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Utilities related to layer/model functionality."""

import tensorflow.compat.v2 as tf

import functools
import weakref

from keras.utils import io_utils
import numpy as np
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.utils.get_source_inputs')
def get_source_inputs(tensor, layer=None, node_index=None):
  """Returns the list of input tensors necessary to compute `tensor`.

  Starting from `tensor`, the inbound nodes of the origin layer are followed
  recursively until input nodes are reached. Each source tensor is reported
  once, in the order in which it is first encountered.

  Args:
      tensor: The tensor to start from.
      layer: Origin layer of the tensor. Will be
          determined via tensor._keras_history if not provided.
      node_index: Origin node index of the tensor. Will be
          determined via tensor._keras_history if not provided.

  Returns:
      List of input tensors (potentially with 1 element). As a special case,
      a `tensor` without a `_keras_history` attribute is returned as-is,
      i.e. not wrapped in a list.
  """
  if not hasattr(tensor, '_keras_history'):
    return tensor

  # Compare against `None` explicitly: 0 is a valid node index, and an
  # explicitly passed non-zero index must not be silently overwritten.
  if layer is None or node_index is None:
    layer, node_index, _ = tensor._keras_history
  if not layer._inbound_nodes:
    return [tensor]

  node = layer._inbound_nodes[node_index]
  if node.is_input:
    # Reached an Input layer, stop recursion.
    return tf.nest.flatten(node.input_tensors)

  source_tensors = []
  for inbound_layer, inbound_index, _, inbound_tensor in (
      node.iterate_inbound()):
    previous_sources = get_source_inputs(inbound_tensor, inbound_layer,
                                         inbound_index)
    # Avoid input redundancy. Identity (not equality) is used on purpose so
    # that no tensor comparison operator is invoked.
    for x in previous_sources:
      if all(x is not t for t in source_tensors):
        source_tensors.append(x)
  return source_tensors


def validate_string_arg(input_data,
                        allowable_strings,
                        layer_name,
                        arg_name,
                        allow_none=False,
                        allow_callables=False):
  """Validates the correctness of a string-based arg.

  Args:
      input_data: The argument value to validate.
      allowable_strings: Container of the accepted string values.
      layer_name: Name of the layer owning the argument; only used in the
          error message.
      arg_name: Name of the argument; only used in the error message.
      allow_none: Whether `None` is an accepted value.
      allow_callables: Whether any callable is an accepted value.

  Raises:
      ValueError: If `input_data` is not one of the accepted values.
  """
  if allow_none and input_data is None:
    return
  elif allow_callables and callable(input_data):
    return
  elif isinstance(input_data, str) and input_data in allowable_strings:
    return
  else:
    # Build a human readable enumeration of everything that would have been
    # accepted for this argument.
    allowed_args = '`None`, ' if allow_none else ''
    allowed_args += 'a `Callable`, ' if allow_callables else ''
    allowed_args += 'or one of the following values: %s' % (allowable_strings,)
    if allow_callables:
      callable_note = (
          f'If restoring a model and `{arg_name}` is a custom callable, '
          'please ensure the callable is registered as a custom object. '
          'See https://www.tensorflow.org/guide/keras/save_and_serialize'
          '#registering_the_custom_object for details. ')
    else:
      callable_note = ''
    raise ValueError(
        f'Unknown value for `{arg_name}` argument of layer {layer_name}. '
        f'{callable_note}Allowed values are: {allowed_args}. Received: '
        f'{input_data}')


def count_params(weights):
  """Count the total number of scalars composing the weights.

  Weights are de-duplicated by object identity, so the same weight object
  listed several times is only counted once. Entries without a `shape`
  attribute are skipped, and unknown (`None`) dimensions count as 0.

  Args:
      weights: An iterable containing the weights on which to compute params

  Returns:
      The total number of scalars composing the weights
  """
  by_identity = {}
  for weight in weights:
    by_identity[id(weight)] = weight

  total = 0
  for weight in by_identity.values():
    # Ignore TrackableWeightHandlers, which will not have a shape defined.
    if not hasattr(weight, 'shape'):
      continue
    dims = [0 if dim is None else dim for dim in weight.shape.as_list()]
    total += np.prod(dims)
  return int(total)


def print_summary(model,
                  line_length=None,
                  positions=None,
                  print_fn=None,
                  expand_nested=False,
                  show_trainable=False):
  """Prints a summary of a model.

  Sequential-like models (Sequential models, subclassed models, and graph
  networks that form a single chain without shared layers) are printed with
  three columns. Other graph networks get an additional 'Connected to'
  column.

  Args:
      model: Keras model instance.
      line_length: Total length of printed lines
          (e.g. set this to adapt the display to different
          terminal window sizes). If not provided, defaults to 65 for
          sequential-like models and to 98 otherwise.
      positions: Relative or absolute positions of log elements in each line.
          If the last entry is `<= 1` the entries are interpreted as
          fractions of `line_length`, otherwise as absolute columns.
          If not provided, defaults to `[.45, .85, 1.]` for sequential-like
          models and to `[.33, .55, .67, 1.]` otherwise. The passed sequence
          is not modified.
      print_fn: Print function to use.
          It will be called on each line of the summary.
          You can set it to a custom function
          in order to capture the string summary.
          If not provided, `io_utils.print_msg` is used.
      expand_nested: Whether to expand the nested models.
          If not provided, defaults to `False`.
      show_trainable: Whether to show if a layer is trainable. Adds an extra
          'Trainable' column and widens each line by 11 characters.
          If not provided, defaults to `False`.
  """
  if print_fn is None:
    print_fn = io_utils.print_msg

  if model.__class__.__name__ == 'Sequential':
    sequential_like = True
  elif not model._is_graph_network:
    # We treat subclassed models as a simple sequence of layers, for logging
    # purposes.
    sequential_like = True
  else:
    sequential_like = True
    nodes_by_depth = model._nodes_by_depth.values()
    nodes = []
    for v in nodes_by_depth:
      if (len(v) > 1) or (len(v) == 1 and
                          len(tf.nest.flatten(v[0].keras_inputs)) > 1):
        # if the model has multiple nodes
        # or if the nodes have multiple inbound_layers
        # the model is no longer sequential
        sequential_like = False
        break
      nodes += v
    if sequential_like:
      # search for shared layers
      for layer in model.layers:
        flag = False
        for node in layer._inbound_nodes:
          if node in nodes:
            if flag:
              sequential_like = False
              break
            else:
              flag = True
        if not sequential_like:
          break

  if sequential_like:
    line_length = line_length or 65
    positions = positions or [.45, .85, 1.]
    if positions[-1] <= 1:
      positions = [int(line_length * p) for p in positions]
    # header names for the different log elements
    to_display = ['Layer (type)', 'Output Shape', 'Param #']
  else:
    line_length = line_length or 98
    positions = positions or [.33, .55, .67, 1.]
    if positions[-1] <= 1:
      positions = [int(line_length * p) for p in positions]
    # header names for the different log elements
    to_display = ['Layer (type)', 'Output Shape', 'Param #', 'Connected to']
    relevant_nodes = []
    for v in model._nodes_by_depth.values():
      relevant_nodes += v

  if show_trainable:
    line_length += 11
    # Build a new list instead of appending in place: absolute `positions`
    # supplied by the caller are still the caller's own object at this point
    # and must not be mutated (and may be a tuple).
    positions = list(positions) + [line_length]
    to_display.append('Trainable')

  def print_row(fields, positions, nested_level=0):
    """Prints one table row, wrapping over-long cells onto extra lines.

    Args:
        fields: values of the row, one per column.
        positions: absolute end column of each column.
        nested_level: nesting depth; that many '|' characters frame the row.
    """
    left_to_print = [str(x) for x in fields]
    # Each pass emits one physical line and consumes a prefix of every cell.
    while any(left_to_print):
      line = ''
      for col in range(len(left_to_print)):
        if col > 0:
          start_pos = positions[col - 1]
        else:
          start_pos = 0
        end_pos = positions[col]
        # Leave room for 2 spaces to delineate columns
        # we don't need any if we are printing the last column
        space = 2 if col != len(positions) - 1 else 0
        cutoff = end_pos - start_pos - space
        fit_into_line = left_to_print[col][:cutoff]
        # For nicer formatting we line-break on seeing end of
        # tuple/dict etc.
        line_break_conditions = ('),', '},', '],', "',")
        candidate_cutoffs = [
            fit_into_line.find(x) + len(x)
            for x in line_break_conditions
            if fit_into_line.find(x) >= 0
        ]
        if candidate_cutoffs:
          cutoff = min(candidate_cutoffs)
          fit_into_line = fit_into_line[:cutoff]

        if col == 0:
          line += '|' * nested_level + ' '
        line += fit_into_line
        line += ' ' * space if space else ''
        left_to_print[col] = left_to_print[col][cutoff:]

        # Pad out to the next position
        if nested_level:
          line += ' ' * (positions[col] - len(line) - nested_level)
        else:
          line += ' ' * (positions[col] - len(line))
      line += '|' * nested_level
      print_fn(line)

  print_fn('Model: "{}"'.format(model.name))
  print_fn('_' * line_length)
  print_row(to_display, positions)
  print_fn('=' * line_length)

  def print_layer_summary(layer, nested_level=0):
    """Prints a summary for a single layer.

    Args:
        layer: target layer.
        nested_level: level of nesting of the layer inside its parent layer
          (e.g. 0 for a top-level layer, 1 for a nested layer).
    """
    try:
      output_shape = layer.output_shape
    except AttributeError:
      output_shape = 'multiple'
    except RuntimeError:  # output_shape unknown in Eager mode.
      output_shape = '?'
    name = layer.name
    cls_name = layer.__class__.__name__
    if not layer.built and not getattr(layer, '_is_graph_network', False):
      # If a subclassed model has a layer that is not called in Model.call, the
      # layer will not be built and we cannot call layer.count_params().
      params = '0 (unused)'
    else:
      params = layer.count_params()
    fields = [name + ' (' + cls_name + ')', output_shape, params]

    if show_trainable:
      fields.append('Y' if layer.trainable else 'N')

    print_row(fields, positions, nested_level)

  def print_layer_summary_with_connections(layer, nested_level=0):
    """Prints a summary for a single layer (including topological connections).

    Args:
        layer: target layer.
        nested_level: level of nesting of the layer inside its parent layer
          (e.g. 0 for a top-level layer, 1 for a nested layer).
    """
    try:
      output_shape = layer.output_shape
    except AttributeError:
      output_shape = 'multiple'
    connections = []
    for node in layer._inbound_nodes:
      if relevant_nodes and node not in relevant_nodes:
        # node is not part of the current network
        continue

      for inbound_layer, node_index, tensor_index, _ in node.iterate_inbound():
        connections.append('{}[{}][{}]'.format(inbound_layer.name, node_index,
                                               tensor_index))

    name = layer.name
    cls_name = layer.__class__.__name__
    fields = [
        name + ' (' + cls_name + ')', output_shape,
        layer.count_params(), connections
    ]

    if show_trainable:
      fields.append('Y' if layer.trainable else 'N')

    print_row(fields, positions, nested_level)

  def print_layer(layer, nested_level=0, is_nested_last=False):
    """Prints a layer and, if `expand_nested` is set, its nested layers.

    Args:
        layer: target layer.
        nested_level: level of nesting of the layer inside its parent layer.
        is_nested_last: whether the layer is the last one of its parent; the
          trailing separator line is skipped in that case.
    """
    if sequential_like:
      print_layer_summary(layer, nested_level)
    else:
      print_layer_summary_with_connections(layer, nested_level)

    if expand_nested and hasattr(layer, 'layers') and layer.layers:
      print_fn('|' * (nested_level + 1) + '¯' *
               (line_length - 2 * nested_level - 2) + '|' * (nested_level + 1))

      nested_layer = layer.layers
      # NOTE: this rebinds the parameter. After the loop it is True, so the
      # separator line below is not printed after an expanded layer.
      is_nested_last = False
      for i in range(len(nested_layer)):
        if i == len(nested_layer) - 1:
          is_nested_last = True
        print_layer(nested_layer[i], nested_level + 1, is_nested_last)

      print_fn('|' * nested_level + '¯' * (line_length - 2 * nested_level) +
               '|' * nested_level)

    if not is_nested_last:
      print_fn('|' * nested_level + ' ' * (line_length - 2 * nested_level) +
               '|' * nested_level)

  layers = model.layers
  for layer in layers:
    print_layer(layer)
  print_fn('=' * line_length)

  if hasattr(model, '_collected_trainable_weights'):
    trainable_count = count_params(model._collected_trainable_weights)
  else:
    trainable_count = count_params(model.trainable_weights)

  non_trainable_count = count_params(model.non_trainable_weights)

  print_fn('Total params: {:,}'.format(trainable_count + non_trainable_count))
  print_fn('Trainable params: {:,}'.format(trainable_count))
  print_fn('Non-trainable params: {:,}'.format(non_trainable_count))
  print_fn('_' * line_length)


def convert_dense_weights_data_format(dense,
                                      previous_feature_map_shape,
                                      target_data_format='channels_first'):
  """Utility useful when changing a convnet's `data_format`.

  When a convnet contains a `Flatten` layer (applied to the last
  convolutional feature map) followed by a `Dense` layer, porting its weights
  from one data format to the other requires permuting the rows of the
  `Dense` kernel so they follow the new dimension ordering. The layer's
  weights are read with `get_weights()`, reordered and written back with
  `set_weights()`.

  Args:
      dense: The target `Dense` layer.
      previous_feature_map_shape: A shape tuple of 3 integers,
          e.g. `(512, 7, 7)`. The shape of the convolutional
          feature map right before the `Flatten` layer that
          came before the target `Dense` layer.
      target_data_format: One of "channels_last", "channels_first".
          Set it "channels_last"
          if converting a "channels_first" model to "channels_last",
          or reciprocally.
  """
  assert target_data_format in {'channels_last', 'channels_first'}
  kernel, bias = dense.get_weights()
  to_channels_first = target_data_format == 'channels_first'

  def _reorder(column):
    """Returns one kernel column laid out for the target data format."""
    if to_channels_first:
      channels, height, width = previous_feature_map_shape
      feature_map = column.reshape((height, width, channels))
      feature_map = np.transpose(feature_map, (2, 0, 1))  # last -> first
    else:
      height, width, channels = previous_feature_map_shape
      feature_map = column.reshape((channels, height, width))
      feature_map = np.transpose(feature_map, (1, 2, 0))  # first -> last
    return np.reshape(feature_map, (np.prod(previous_feature_map_shape),))

  # Every output unit owns one column of the kernel; convert them one by one.
  for unit in range(kernel.shape[1]):
    kernel[:, unit] = _reorder(kernel[:, unit])
  dense.set_weights([kernel, bias])


def is_builtin_layer(layer):
  """Returns whether `layer` carries its own Keras API export names.

  Args:
      layer: The object to inspect.

  Returns:
      `False` if `layer` has no (or empty) `_keras_api_names`, or if either
      `_keras_api_names` or `_keras_api_names_v1` equals
      `('keras.layers.Layer',)`; `True` otherwise.
  """
  base_layer_names = ('keras.layers.Layer',)

  api_names = getattr(layer, '_keras_api_names', None)
  if not api_names:
    return False

  # Subclasses of `Layer` that are not exported inherit the export name
  # of the base layer class.
  if api_names != base_layer_names:
    return layer._keras_api_names_v1 != base_layer_names
  return False


def cached_per_instance(f):
  """Lightweight decorator for caching lazily constructed properties.

  When to use:
  This decorator provides simple caching with minimal overhead. It is designed
  for properties which are expensive to compute and static over the life of a
  class instance, and provides no mechanism for cache invalidation. Thus it is
  best suited for lazily exposing derived properties of other static data.

  For classes with custom getattr / setattr behavior (such as trackable
  objects), storing cache results as object attributes is not performant.
  Instead, a specialized cache can significantly reduce property lookup
  overhead. (While still allowing the decorated property to be lazily computed.)
  Consider the following class:

  ```
  class MyClass:
    def __setattr__(self, key, value):
      # Some expensive class specific code
      # ...
      # ...

      super(MyClass, self).__setattr__(key, value)

    @property
    def thing(self):
      # `thing` is expensive to compute (and may not even be requested), so we
      # want to lazily compute it and then cache it.
      output = getattr(self, '_thing', None)
      if output is None:
        self._thing = output = compute_thing(self)
      return output
  ```

  It's also worth noting that ANY overriding of __setattr__, even something as
  simple as:
  ```
    def __setattr__(self, key, value):
      super(MyClass, self).__setattr__(key, value)
  ```

  Slows down attribute assignment by nearly 10x.

  By contrast, replacing the definition of `thing` with the following sidesteps
  the expensive __setattr__ altogether:

  ```
  @property
  @cached_per_instance
  def thing(self):
    # `thing` is expensive to compute (and may not even be requested), so we
    # want to lazily compute it and then cache it.
    return compute_thing(self)
  ```

  Performance:
  The overhead for this decorator is ~0.4 us / call. A much lower overhead
  implementation (~0.085 us / call) can be achieved by using a custom dict type:

  ```
  def dict_based_cache(f):
    class Cache(dict):
      __slots__ = ()
      def __missing__(self, key):
        self[key] = output = f(key)
        return output

    return property(Cache().__getitem__)
  ```

  However, that implementation holds class instances as keys, and as a result
  blocks garbage collection. (And modifying it to use weakref's as keys raises
  the lookup overhead to ~0.4 us) As a result, the WeakKeyDictionary
  implementation below turns out to be more prudent.

  Instances must be hashable and weak-referenceable to be used as cache keys.
  Every computed value is cached, including `None`.

  Args:
    f: The function to cache. It is called with the instance as its only
      argument.

  Returns:
    f decorated with simple caching behavior. The underlying
    `WeakKeyDictionary` is exposed as the `cache` attribute of the result.
  """

  cache = weakref.WeakKeyDictionary()
  # Private sentinel marking "not computed yet". Comparing against `None`
  # instead would recompute `f` on every access whenever it returns `None`.
  missing = object()

  @functools.wraps(f)
  def wrapped(item):
    output = cache.get(item, missing)
    if output is missing:
      cache[item] = output = f(item)
    return output

  wrapped.cache = cache
  return wrapped


def filter_empty_layer_containers(layer_list):
  """Filter out empty Layer-like containers and uniquify.

  Generator that walks `layer_list` in order. Instances exposing an
  `_is_layer` attribute are yielded once each (deduplicated by identity).
  Any other object is treated as a container and replaced in place by the
  contents of its `layers` attribute, if it has one.

  Args:
      layer_list: List of layers and/or layer containers.

  Yields:
      The unique layer instances, in first-encounter order.
  """
  # TODO(b/130381733): Make this an attribute in base_layer.Layer.
  seen_ids = set()
  # Reversed so that popping from the end visits items in the original order.
  pending = layer_list[::-1]
  while pending:
    candidate = pending.pop()
    candidate_id = id(candidate)
    if candidate_id in seen_ids:
      continue
    seen_ids.add(candidate_id)

    is_layer_instance = (
        hasattr(candidate, '_is_layer') and not isinstance(candidate, type))
    if is_layer_instance:
      yield candidate
      continue

    # Trackable data structures will not show up in ".layers" lists, but
    # the layers they contain will.
    children = getattr(candidate, 'layers', None) or []
    pending.extend(children[::-1])
