在win7 64位,Anaconda安装的Python3.6.1下安装的TensorFlow与Keras,Keras的backend为TensorFlow。在运行Mask R-CNN时,在进行调试时想知道PyCharm (Python IDE)底部窗口输出的Loss格式是在哪里定义的,如下图红框中所示:
图1 训练过程的Loss格式化输出
在上图红框中,Loss的输出格式是在哪里定义的呢?有一点是明确的,即上图红框中的内容是在训练的时候输出的。那么先来看一下Mask R-CNN的训练过程。Keras以Numpy数组作为输入数据和标签的数据类型。训练模型一般使用 fit 函数。然而由于Mask R-CNN训练数据巨大,不能一次性全部载入,否则太消耗内存。于是采用生成器的方式一次载入一个batch的数据,而且是在用到这个batch的数据才开始载入的,那么它的训练函数如下:
self.keras_model.fit_generator( train_generator, initial_epoch=self.epoch, epochs=epochs, steps_per_epoch=self.config.STEPS_PER_EPOCH, callbacks=callbacks, validation_data=val_generator, validation_steps=self.config.VALIDATION_STEPS, max_queue_size=100, workers=workers, use_multiprocessing=False, )
这里训练模型的函数相应的为 fit_generator 函数。注意其中的参数callbacks=callbacks,这个参数在输出红框中的内容起到了关键性的作用。下面看一下callbacks的值:
# Callbacks callbacks = [ keras.callbacks.TensorBoard(log_dir=self.log_dir, histogram_freq=0, write_graph=True, write_images=False), keras.callbacks.ModelCheckpoint(self.checkpoint_path, verbose=0, save_weights_only=True), ]
在输出红框中的内容所需的数据均保存在self.log_dir下。然后调试进入self.keras_model.fit_generator函数,进入keras,legacy.interfaces的legacy_support(func)函数,如下所示:
def legacy_support(func): @six.wraps(func) def wrapper(*args, **kwargs): if object_type == 'class': object_name = args[0].__class__.__name__ else: object_name = func.__name__ if preprocessor: args, kwargs, converted = preprocessor(args, kwargs) else: converted = [] if check_positional_args: if len(args) > len(allowed_positional_args) + 1: raise TypeError('`' + object_name + '` can accept only ' + str(len(allowed_positional_args)) + ' positional arguments ' + str(tuple(allowed_positional_args)) + ', but you passed the following ' 'positional arguments: ' + str(list(args[1:]))) for key in value_conversions: if key in kwargs: old_value = kwargs[key] if old_value in value_conversions[key]: kwargs[key] = value_conversions[key][old_value] for old_name, new_name <strong>本文来源gao@daima#com搞(%代@#码@网2</strong>in conversions: if old_name in kwargs: value = kwargs.pop(old_name) if new_name in kwargs: raise_duplicate_arg_error(old_name, new_name) kwargs[new_name] = value converted.append((new_name, old_name)) if converted: signature = '`' + object_name + '(' for i, value in enumerate(args[1:]): if isinstance(value, six.string_types): signature += '"' + value + '"' else: if isinstance(value, np.ndarray): str_val = 'array' else: str_val = str(value) if len(str_val) > 10: str_val = str_val[:10] + '...' signature += str_val if i < len(args[1:]) - 1 or kwargs: signature += ', ' for i, (name, value) in enumerate(kwargs.items()): signature += name + '=' if isinstance(value, six.string_types): signature += '"' + value + '"' else: if isinstance(value, np.ndarray): str_val = 'array' else: str_val = str(value) if len(str_val) > 10: str_val = str_val[:10] + '...' signature += str_val if i < len(kwargs) - 1: signature += ', ' signature += ')`' warnings.warn('Update your `' + object_name + '` call to the Keras 2 API: ' + signature, stacklevel=2) return func(*args, **kwargs) wrapper._original_function = func return wrapper return legacy_support