网络编程
位置:首页>> 网络编程>> Python编程>> 基于Keras的格式化输出Loss实现方式

基于Keras的格式化输出Loss实现方式

作者:鹊踏枝-码农  发布时间:2021-10-20 20:44:00 

标签:Keras,格式化,输出,Loss

在win7 64位,Anaconda安装的Python3.6.1下安装的TensorFlow与Keras,Keras的backend为TensorFlow。在运行Mask R-CNN时,在进行调试时想知道PyCharm (Python IDE)底部窗口输出的Loss格式是在哪里定义的,如下图红框中所示:

基于Keras的格式化输出Loss实现方式

图1 训练过程的Loss格式化输出

在上图红框中,Loss的输出格式是在哪里定义的呢?有一点是明确的,即上图红框中的内容是在训练的时候输出的。那么先来看一下Mask R-CNN的训练过程。Keras以Numpy数组作为输入数据和标签的数据类型。训练模型一般使用 fit 函数。然而由于Mask R-CNN训练数据巨大,不能一次性全部载入,否则太消耗内存。于是采用生成器的方式一次载入一个batch的数据,而且是在用到这个batch的数据才开始载入的,那么它的训练函数如下:


self.keras_model.fit_generator(
  train_generator,
  initial_epoch=self.epoch,
  epochs=epochs,
  steps_per_epoch=self.config.STEPS_PER_EPOCH,
  callbacks=callbacks,
  validation_data=val_generator,
  validation_steps=self.config.VALIDATION_STEPS,
  max_queue_size=100,
  workers=workers,
  use_multiprocessing=False,
 )

这里训练模型的函数相应的为 fit_generator 函数。注意其中的参数callbacks=callbacks,这个参数在输出红框中的内容起到了关键性的作用。下面看一下callbacks的值:


# Callbacks
 callbacks = [
  keras.callbacks.TensorBoard(log_dir=self.log_dir,
         histogram_freq=0, write_graph=True, write_images=False),
  keras.callbacks.ModelCheckpoint(self.checkpoint_path,
          verbose=0, save_weights_only=True),
 ]

在输出红框中的内容所需的数据均保存在self.log_dir下。然后调试进入self.keras_model.fit_generator函数,进入keras,legacy.interfaces的legacy_support(func)函数,如下所示:


def legacy_support(func):
 @six.wraps(func)
 def wrapper(*args, **kwargs):
  if object_type == 'class':
   object_name = args[0].__class__.__name__
  else:
   object_name = func.__name__
  if preprocessor:
   args, kwargs, converted = preprocessor(args, kwargs)
  else:
   converted = []
  if check_positional_args:
   if len(args) > len(allowed_positional_args) + 1:
    raise TypeError('`' + object_name +
        '` can accept only ' +
        str(len(allowed_positional_args)) +
        ' positional arguments ' +
        str(tuple(allowed_positional_args)) +
        ', but you passed the following '
        'positional arguments: ' +
        str(list(args[1:])))
  for key in value_conversions:
   if key in kwargs:
    old_value = kwargs[key]
    if old_value in value_conversions[key]:
     kwargs[key] = value_conversions[key][old_value]
  for old_name, new_name in conversions:
   if old_name in kwargs:
    value = kwargs.pop(old_name)
    if new_name in kwargs:
     raise_duplicate_arg_error(old_name, new_name)
    kwargs[new_name] = value
    converted.append((new_name, old_name))
  if converted:
   signature = '`' + object_name + '('
   for i, value in enumerate(args[1:]):
    if isinstance(value, six.string_types):
     signature += '"' + value + '"'
    else:
     if isinstance(value, np.ndarray):
      str_val = 'array'
     else:
      str_val = str(value)
     if len(str_val) > 10:
      str_val = str_val[:10] + '...'
     signature += str_val
    if i < len(args[1:]) - 1 or kwargs:
     signature += ', '
   for i, (name, value) in enumerate(kwargs.items()):
    signature += name + '='
    if isinstance(value, six.string_types):
     signature += '"' + value + '"'
    else:
     if isinstance(value, np.ndarray):
      str_val = 'array'
     else:
      str_val = str(value)
     if len(str_val) > 10:
      str_val = str_val[:10] + '...'
     signature += str_val
    if i < len(kwargs) - 1:
     signature += ', '
   signature += ')`'
   warnings.warn('Update your `' + object_name +
       '` call to the Keras 2 API: ' + signature, stacklevel=2)
  return func(*args, **kwargs)
 wrapper._original_function = func
 return wrapper
return legacy_support

在上述代码的倒数第4行的return func(*args, **kwargs)处返回func,func为fit_generator函数,现调试进入fit_generator函数,该函数定义在keras.engine.training模块内的fit_generator函数,调试进入函数callbacks.on_epoch_begin(epoch),如下所示:


# Construct epoch logs.
  epoch_logs = {}
  while epoch < epochs:
   for m in self.stateful_metric_functions:
    m.reset_states()
   callbacks.on_epoch_begin(epoch)

调试进入到callbacks.on_epoch_begin(epoch)函数,进入on_epoch_begin函数,如下所示:


def on_epoch_begin(self, epoch, logs=None):
 """Called at the start of an epoch.
 # Arguments
  epoch: integer, index of epoch.
  logs: dictionary of logs.
 """
 logs = logs or {}
 for callback in self.callbacks:
  callback.on_epoch_begin(epoch, logs)
 self._delta_t_batch = 0.
 self._delta_ts_batch_begin = deque([], maxlen=self.queue_length)
 self._delta_ts_batch_end = deque([], maxlen=self.queue_length)

在上述函数on_epoch_begin中调试进入callback.on_epoch_begin(epoch, logs)函数,转到类ProgbarLogger(Callback)中定义的on_epoch_begin函数,如下所示:


class ProgbarLogger(Callback):
"""Callback that prints metrics to stdout.
# Arguments
 count_mode: One of "steps" or "samples".
  Whether the progress bar should
  count samples seen or steps (batches) seen.
 stateful_metrics: Iterable of string names of metrics that
  should *not* be averaged over an epoch.
  Metrics in this list will be logged as-is.
  All others will be averaged over time (e.g. loss, etc).
# Raises
 ValueError: In case of invalid `count_mode`.
"""

def __init__(self, count_mode='samples',
    stateful_metrics=None):
 super(ProgbarLogger, self).__init__()
 if count_mode == 'samples':
  self.use_steps = False
 elif count_mode == 'steps':
  self.use_steps = True
 else:
  raise ValueError('Unknown `count_mode`: ' + str(count_mode))
 if stateful_metrics:
  self.stateful_metrics = set(stateful_metrics)
 else:
  self.stateful_metrics = set()

def on_train_begin(self, logs=None):
 self.verbose = self.params['verbose']
 self.epochs = self.params['epochs']

def on_epoch_begin(self, epoch, logs=None):
 if self.verbose:
  print('Epoch %d/%d' % (epoch + 1, self.epochs))
  if self.use_steps:
   target = self.params['steps']
  else:
   target = self.params['samples']
  self.target = target
  self.progbar = Progbar(target=self.target,
        verbose=self.verbose,
        stateful_metrics=self.stateful_metrics)
 self.seen = 0

在上述代码的

print('Epoch %d/%d' % (epoch + 1, self.epochs))

输出

Epoch 1/40(如红框中所示内容的第一行)。

然后返回到keras.engine.training模块内的fit_generator函数,执行到self.train_on_batch函数,如下所示:


outs = self.train_on_batch(x, y,
    sample_weight=sample_weight,
    class_weight=class_weight)

if not isinstance(outs, list):
     outs = [outs]
    for l, o in zip(out_labels, outs):
     batch_logs[l] = o

callbacks.on_batch_end(batch_index, batch_logs)

batch_index += 1
    steps_done += 1

调试进入上述代码中的callbacks.on_batch_end(batch_index, batch_logs)函数,进入到on_batch_end函数后,该函数的定义如下所示:


def on_batch_end(self, batch, logs=None):
 """Called at the end of a batch.
 # Arguments
  batch: integer, index of batch within the current epoch.
  logs: dictionary of logs.
 """
 logs = logs or {}
 if not hasattr(self, '_t_enter_batch'):
  self._t_enter_batch = time.time()
 self._delta_t_batch = time.time() - self._t_enter_batch
 t_before_callbacks = time.time()
 for callback in self.callbacks:
  callback.on_batch_end(batch, logs)
 self._delta_ts_batch_end.append(time.time() - t_before_callbacks)
 delta_t_median = np.median(self._delta_ts_batch_end)
 if (self._delta_t_batch > 0. and
  (delta_t_median > 0.95 * self._delta_t_batch and delta_t_median > 0.1)):
  warnings.warn('Method on_batch_end() is slow compared '
      'to the batch update (%f). Check your callbacks.'
      % delta_t_median)

接着继续调试进入上述代码中的callback.on_batch_end(batch, logs)函数,进入到在类中ProgbarLogger(Callback)定义的on_batch_end函数,如下所示:


def on_batch_end(self, batch, logs=None):
 logs = logs or {}
 batch_size = logs.get('size', 0)
 if self.use_steps:
  self.seen += 1
 else:
  self.seen += batch_size

for k in self.params['metrics']:
  if k in logs:
   self.log_values.append((k, logs[k]))

# Skip progbar update for the last batch;
 # will be handled by on_epoch_end.
 if self.verbose and self.seen < self.target:
  self.progbar.update(self.seen, self.log_values)

然后执行到上述代码的最后一行self.progbar.update(self.seen, self.log_values),调试进入update函数,该函数定义在模块keras.utils.generic_utils中的类Progbar(object)定义的函数。类的定义及方法如下所示:


class Progbar(object):
"""Displays a progress bar.
# Arguments
 target: Total number of steps expected, None if unknown.
 width: Progress bar width on screen.
 verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose)
 stateful_metrics: Iterable of string names of metrics that
  should *not* be averaged over time. Metrics in this list
  will be displayed as-is. All others will be averaged
  by the progbar before display.
 interval: Minimum visual progress update interval (in seconds).
"""

def __init__(self, target, width=30, verbose=1, interval=0.05,
    stateful_metrics=None):
 self.target = target
 self.width = width
 self.verbose = verbose
 self.interval = interval
 if stateful_metrics:
  self.stateful_metrics = set(stateful_metrics)
 else:
  self.stateful_metrics = set()

self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and
        sys.stdout.isatty()) or
        'ipykernel' in sys.modules)
 self._total_width = 0
 self._seen_so_far = 0
 self._values = collections.OrderedDict()
 self._start = time.time()
 self._last_update = 0

def update(self, current, values=None):
 """Updates the progress bar.
 # Arguments
  current: Index of current step.
  values: List of tuples:
   `(name, value_for_last_step)`.
   If `name` is in `stateful_metrics`,
   `value_for_last_step` will be displayed as-is.
   Else, an average of the metric over time will be displayed.
 """
 values = values or []
 for k, v in values:
  if k not in self.stateful_metrics:
   if k not in self._values:
    self._values[k] = [v * (current - self._seen_so_far),
         current - self._seen_so_far]
   else:
    self._values[k][0] += v * (current - self._seen_so_far)
    self._values[k][1] += (current - self._seen_so_far)
  else:
   # Stateful metrics output a numeric value. This representation
   # means "take an average from a single value" but keeps the
   # numeric formatting.
   self._values[k] = [v, 1]
 self._seen_so_far = current

now = time.time()
 info = ' - %.0fs' % (now - self._start)
 if self.verbose == 1:
  if (now - self._last_update < self.interval and
    self.target is not None and current < self.target):
   return

prev_total_width = self._total_width
  if self._dynamic_display:
   sys.stdout.write('\b' * prev_total_width)
   sys.stdout.write('\r')
  else:
   sys.stdout.write('\n')

if self.target is not None:
   numdigits = int(np.floor(np.log10(self.target))) + 1
   barstr = '%%%dd/%d [' % (numdigits, self.target)
   bar = barstr % current
   prog = float(current) / self.target
   prog_width = int(self.width * prog)
   if prog_width > 0:
    bar += ('=' * (prog_width - 1))
    if current < self.target:
     bar += '>'
    else:
     bar += '='
   bar += ('.' * (self.width - prog_width))
   bar += ']'
  else:
   bar = '%7d/Unknown' % current

self._total_width = len(bar)
  sys.stdout.write(bar)

if current:
   time_per_unit = (now - self._start) / current
  else:
   time_per_unit = 0
  if self.target is not None and current < self.target:
   eta = time_per_unit * (self.target - current)
   if eta > 3600:
    eta_format = '%d:%02d:%02d' % (eta // 3600, (eta % 3600) // 60, eta % 60)
   elif eta > 60:
    eta_format = '%d:%02d' % (eta // 60, eta % 60)
   else:
    eta_format = '%ds' % eta

info = ' - ETA: %s' % eta_format
  else:
   if time_per_unit >= 1:
    info += ' %.0fs/step' % time_per_unit
   elif time_per_unit >= 1e-3:
    info += ' %.0fms/step' % (time_per_unit * 1e3)
   else:
    info += ' %.0fus/step' % (time_per_unit * 1e6)

for k in self._values:
   info += ' - %s:' % k
   if isinstance(self._values[k], list):
    avg = np.mean(
     self._values[k][0] / max(1, self._values[k][1]))
    if abs(avg) > 1e-3:
     info += ' %.4f' % avg
    else:
     info += ' %.4e' % avg
   else:
    info += ' %s' % self._values[k]

self._total_width += len(info)
  if prev_total_width > self._total_width:
   info += (' ' * (prev_total_width - self._total_width))

if self.target is not None and current >= self.target:
   info += '\n'

sys.stdout.write(info)
  sys.stdout.flush()

elif self.verbose == 2:
  if self.target is None or current >= self.target:
   for k in self._values:
    info += ' - %s:' % k
    avg = np.mean(
     self._values[k][0] / max(1, self._values[k][1]))
    if avg > 1e-3:
     info += ' %.4f' % avg
    else:
     info += ' %.4e' % avg
   info += '\n'

sys.stdout.write(info)
   sys.stdout.flush()

self._last_update = now

def add(self, n, values=None):
 self.update(self._seen_so_far + n, values)

重点是上述代码中的update(self, current, values=None)函数,在该函数内设置断点,即可调入该函数。下面重点分析上述代码中的几个输出条目:

1. sys.stdout.write('\n') #换行

2. sys.stdout.write('bar') #输出 [..................],其中bar= [..................];

3. sys.stdout.write(info) #输出loss格式,其中info='- ETA:...';

4. sys.stdout.flush() #刷新缓存,立即得到输出。

通过对Mask R-CNN代码的调试分析可知,图1中的红框中的训练过程中的Loss格式化输出是由built-in模块实现的。若想得到类似的格式化输出,关键在self.keras_model.fit_generator函数中传入callbacks参数和callbacks中内容的定义。

来源:https://blog.csdn.net/u011501388/article/details/81088690

0
投稿

猜你喜欢

  • 本文实例讲述了正则表达式验证IPV4地址功能。分享给大家供大家参考,具体如下:IPV4地址由4个组数字组成,每组数字之间以.分隔,每组数字的
  • 本文实例讲述了Python快速排序算法。分享给大家供大家参考,具体如下:快速排序的时间复杂度是O(NlogN)算法描述:① 先从序列中取出一
  • 指令和程序计算机的硬件系统通常由五大部件构成,包括:运算器、控制器、存储器、输入设备和输出设备。其中,运算器和控制器放在一起就是我们通常所说
  • <?php function genpage(&$sql,$page_size=10) { global $pages,$su
  • 1、需求当工作在UNIX Shell下时,我们想使用常见的通配符模式(即:.py,Dat[0-9].csv等)来对文本做匹配。2、解决方案f
  • 写在前面:前一段时间 kejun 给我们培训JavaScript的时候,在幻灯片上推荐了很多特别经典的文章,其中就有这一篇。读过之后感觉很不
  • Hihi, 大家好~ 最近有不少人都提及了网页上该如何选择字体的问题。问题虽然小,但是却是前端开发中的基本,因为目前的网页,还是以文字信息
  • 1 蚂蚁森林简介蚂蚁森林是一项旨在带动公众低碳减排的公益项目,每个人的低碳行为在蚂蚁森林里可计为"绿色能量"。"
  • 百度AI功能还是很强大的,百度AI开放平台真的是测试接口的天堂,免费接口很多,当然有量的限制,但个人使用是完全够用的,什么人脸识别、MQTT
  • 如何用ASP发送带附件的邮件?请问如何用CDONTS组件发送带附件的邮件?    见下列代码:<%&nb
  • 前言:在网络时代,图片已经成为了我们生活中不可或缺的一部分。随着各种社交媒体的兴起,我们可以在网上看到越来越多的图片,但是如何从这些图片中获
  • 安装pip install requests发送网络请求import requestsr=requests.get('http://
  • XMLHTTP对象及其方法------------------MSXML中提供了Microsoft.XMLHTTP对象,能够完成从数据包到R
  • 今天在 ajaxian 上看到一篇文章,名为 Five Ajax Anti-pattern ,觉得讲得比较有道理,现粗略翻译一下,加一些自己
  • help函数是python的一个内置函数,在python基础知识中介绍过什么是内置函数,它是python自带的函数,任何时候都可以被使。he
  • 看到sam关于max-height的文章,觉得按捺不住了。sam注重于样式表的写法,过多的要求div+css的布局,sam可是追求艺术的人哦
  • OS 模块在讲解包模块时我们提到通过 sys 模块进行查看全局包路径查看于注册,今天我们尝试了解下OS模块,这个模块主要
  • 本文实例讲述了PHP中PDO事务处理操作。分享给大家供大家参考,具体如下:概要:将多条sql操作(增删改)作为一个操作单元,要么都成功,要么
  • 如果你的PHP网站换了空间,必定要对Mysql数据库进行转移,一般的转移的方法,是备份再还原,有点繁琐,而且由于数据库版本的不一样会导致数据
  • 1.介绍在 Golang 语言项目开发中,经常会遇到数据排序问题。Golang 语言标准库 sort 包,为我们提供了数据排序的功能,我们可
手机版 网络编程 asp之家 www.aspxhome.com