I previously wrote up TensorFlow's cross-entropy loss functions in [Tensorflow (older versions) - Cross Entropy Loss](https://www.aiuai.cn/aifarm198.html). The implementation and documentation have changed considerably in newer releases (tensorflow-gpu 1.11.0), so this is a brief updated summary.

tf.losses

Reduction:

The available types of loss reduction:

NONE: Un-reduced weighted losses with the same shape as input.

SUM: Scalar sum of weighted losses.

MEAN: Scalar SUM divided by sum of weights.

SUM_OVER_BATCH_SIZE: Scalar SUM divided by number of elements in losses.

SUM_OVER_NONZERO_WEIGHTS: Scalar SUM divided by number of non-zero weights.

SUM_BY_NONZERO_WEIGHTS: Same as SUM_OVER_NONZERO_WEIGHTS.
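To make these modes concrete, here is a small NumPy sketch (the values are assumed for illustration; this mirrors the definitions above rather than TF's internals):

```python
import numpy as np

# Element-wise losses and weights; one weight is zero to show the
# difference between the batch-size and nonzero-weight reductions.
losses = np.array([1.0, 2.0, 3.0, 4.0])
weights = np.array([1.0, 0.5, 0.0, 1.0])

weighted = losses * weights                       # NONE: same shape as input
s = weighted.sum()                                # SUM
mean = s / weights.sum()                          # MEAN
sum_over_batch = s / losses.size                  # SUM_OVER_BATCH_SIZE
sum_by_nonzero = s / np.count_nonzero(weights)    # SUM_BY_NONZERO_WEIGHTS

print(weighted)                                 # [1.  1.  0.  4.]
print(s, mean, sum_over_batch, sum_by_nonzero)  # 6.0 2.4 1.5 2.0
```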

1. absolute_difference

Computes the absolute-difference loss.

tf.losses.absolute_difference(
    labels,
    predictions,
    weights=1.0,
    scope=None,
    loss_collection=tf.GraphKeys.LOSSES,
    reduction=Reduction.SUM_BY_NONZERO_WEIGHTS
)

labels - the ground-truth tensor.

predictions - the predicted output tensor.
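A minimal NumPy sketch of the computation (assumed values): with the default scalar weights=1.0, SUM_BY_NONZERO_WEIGHTS reduces to a plain mean of the absolute differences.

```python
import numpy as np

# |predictions - labels|, averaged over elements (all weights are 1).
labels = np.array([1.0, 2.0, 3.0])
predictions = np.array([1.5, 1.5, 3.0])

loss = np.abs(predictions - labels).mean()
print(loss)  # 0.3333...
```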

2. compute_weighted_loss

Computes a weighted loss.

tf.losses.compute_weighted_loss(
    losses,
    weights=1.0,
    scope=None,
    loss_collection=tf.GraphKeys.LOSSES,
    reduction=Reduction.SUM_BY_NONZERO_WEIGHTS
)

losses - a Tensor of shape [batch_size, d1, ..., dN].
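A NumPy sketch of the weighted reduction with per-sample weights (assumed values), under the default SUM_BY_NONZERO_WEIGHTS reduction:

```python
import numpy as np

# One loss value per batch element; a zero weight masks a sample out
# of both the numerator and the nonzero-weight count.
losses = np.array([0.2, 0.4, 0.6])
weights = np.array([1.0, 2.0, 0.0])

loss = (losses * weights).sum() / np.count_nonzero(weights)
print(loss)  # 0.5
```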

3. hinge_loss

tf.losses.hinge_loss(
    labels,
    logits,
    weights=1.0,
    scope=None,
    loss_collection=tf.GraphKeys.LOSSES,
    reduction=Reduction.SUM_BY_NONZERO_WEIGHTS
)

labels - the ground-truth tensor, with values expected to be 0.0 or 1.0. Internally the loss maps {0, 1} to {-1, 1}.

logits - a float tensor. The logits are assumed to be unbounded and 0-centered; for the binary prediction, values greater than 0 count as positive and values less than 0 as negative.
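A NumPy sketch of the hinge computation described above (assumed values): {0, 1} labels are mapped to {-1, 1}, then each element contributes max(0, 1 - y * logit).

```python
import numpy as np

labels = np.array([0.0, 1.0, 1.0])
logits = np.array([-2.0, 0.5, 3.0])

y = 2.0 * labels - 1.0                          # {0, 1} -> {-1, 1}
per_example = np.maximum(0.0, 1.0 - y * logits)
loss = per_example.mean()                       # weights all 1
print(per_example)  # [0.  0.5 0. ]
print(loss)         # 0.1666...
```

Only the middle example lies inside the margin; confidently correct predictions contribute zero loss.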

4. log_loss

tf.losses.log_loss(
    labels,
    predictions,
    weights=1.0,
    epsilon=1e-07,
    scope=None,
    loss_collection=tf.GraphKeys.LOSSES,
    reduction=Reduction.SUM_BY_NONZERO_WEIGHTS
)
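log_loss is the binary cross-entropy on predicted probabilities (not logits), with epsilon guarding the logarithms. A NumPy sketch with assumed values:

```python
import numpy as np

labels = np.array([1.0, 0.0])
predictions = np.array([0.9, 0.2])   # probabilities in [0, 1]
eps = 1e-7

per_example = -(labels * np.log(predictions + eps)
                + (1 - labels) * np.log(1 - predictions + eps))
loss = per_example.mean()            # weights all 1
```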

5. mean_pairwise_squared_error

pairwise-errors-squared loss.

tf.losses.mean_pairwise_squared_error(
    labels,
    predictions,
    weights=1.0,
    scope=None,
    loss_collection=tf.GraphKeys.LOSSES
)

mean_pairwise_squared_error measures the differences between pairs of corresponding elements of predictions and labels.

For example, if labels=[a, b, c] and predictions=[x, y, z], the loss is the sum over the three pairs of differences:

loss = [ ((a-b) - (x-y))^2 + ((a-c) - (x-z))^2 + ((b-c) - (y-z))^2 ] / 3.
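Plugging assumed numbers into this formula shows its characteristic property: the loss depends only on differences between elements, so adding a constant to all predictions leaves it unchanged.

```python
a, b, c = 1.0, 4.0, 9.0      # labels
x, y, z = 1.5, 4.0, 8.0      # predictions

loss = (((a - b) - (x - y)) ** 2
        + ((a - c) - (x - z)) ** 2
        + ((b - c) - (y - z)) ** 2) / 3.0
print(loss)  # 1.1666...

# Shifting every prediction by the same constant gives the same loss.
x2, y2, z2 = x + 10, y + 10, z + 10
loss_shifted = (((a - b) - (x2 - y2)) ** 2
                + ((a - c) - (x2 - z2)) ** 2
                + ((b - c) - (y2 - z2)) ** 2) / 3.0
print(loss_shifted == loss)  # True
```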

6. mean_squared_error

Sum-of-Squares loss.

tf.losses.mean_squared_error(
    labels,
    predictions,
    weights=1.0,
    scope=None,
    loss_collection=tf.GraphKeys.LOSSES,
    reduction=Reduction.SUM_BY_NONZERO_WEIGHTS
)

mean_squared_error measures the element-wise differences between predictions and labels.
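A minimal NumPy sketch (assumed values): with scalar weights=1.0 this is the mean of the element-wise squared differences.

```python
import numpy as np

labels = np.array([1.0, 2.0, 3.0])
predictions = np.array([1.0, 2.5, 2.0])

loss = ((predictions - labels) ** 2).mean()
print(loss)  # 0.41666...
```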

7. sigmoid_cross_entropy

A cross-entropy loss built on tf.nn.sigmoid_cross_entropy_with_logits.

weights - coefficients for the loss. If weights is a scalar, the loss is simply multiplied by it; if weights is a tensor of shape [batch_size], each sample's loss is multiplied by its corresponding weight.

label_smoothing - when non-zero, smooths the labels toward 0.5 (label smoothing regularization, LSR): new_multiclass_labels = multi_class_labels * (1 - label_smoothing) + 0.5 * label_smoothing.

multi_class_labels - a [batch_size, num_classes] tensor of target integer labels in {0, 1}.

logits - the network's output logits, of shape [batch_size, num_classes].
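A NumPy sketch of the per-element sigmoid cross-entropy with the smoothing rule above (assumed values; the stable form below is the standard max(x, 0) - x*z + log(1 + exp(-|x|)) rewrite):

```python
import numpy as np

labels = np.array([[1.0, 0.0, 1.0]])   # multi-label targets in {0, 1}
logits = np.array([[2.0, -1.0, 0.5]])
label_smoothing = 0.1

# Smooth the targets toward 0.5.
z = labels * (1 - label_smoothing) + 0.5 * label_smoothing
x = logits

# Numerically stable sigmoid cross-entropy per element.
per_elem = np.maximum(x, 0) - x * z + np.log1p(np.exp(-np.abs(x)))
loss = per_elem.mean()                 # weights all 1
print(loss)
```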

8. softmax_cross_entropy

A cross-entropy loss built on tf.nn.softmax_cross_entropy_with_logits_v2.

tf.losses.softmax_cross_entropy(
    onehot_labels,
    logits,
    weights=1.0,
    label_smoothing=0,
    scope=None,
    loss_collection=tf.GraphKeys.LOSSES,
    reduction=Reduction.SUM_BY_NONZERO_WEIGHTS
)

weights - coefficients for the loss. If weights is a scalar, the loss is simply multiplied by it; if weights is a tensor of shape [batch_size], each sample's loss is multiplied by its corresponding weight.

label_smoothing - when non-zero, smooths the labels toward 1/num_classes (LSR): new_onehot_labels = onehot_labels * (1 - label_smoothing) + label_smoothing / num_classes.

onehot_labels and logits must have the same shape, e.g. [batch_size, num_classes].

onehot_labels - the one-hot encoded ground-truth labels.

logits - the network's output logits.
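A NumPy sketch combining the smoothing rule above with a stable log-softmax (assumed values):

```python
import numpy as np

onehot_labels = np.array([[0.0, 1.0, 0.0]])
logits = np.array([[1.0, 3.0, 0.5]])
label_smoothing = 0.1
num_classes = onehot_labels.shape[-1]

# Smooth the one-hot targets toward the uniform distribution.
smoothed = onehot_labels * (1 - label_smoothing) + label_smoothing / num_classes

# Stable log-softmax: subtract the row max before exponentiating.
shifted = logits - logits.max(axis=-1, keepdims=True)
log_softmax = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

loss = -(smoothed * log_softmax).sum(axis=-1).mean()
print(loss)  # ~0.3467
```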

9. sparse_softmax_cross_entropy

Computes the sparse softmax cross-entropy between logits and labels.

It measures the probability error in discrete classification tasks where the classes are mutually exclusive (each sample belongs to exactly one class).

For example, each CIFAR-10 image carries one and only one class label: an image is a dog or a cat, never both.

Note:

For this loss, the probability of a given label is exclusive, so soft classes are not allowed: the labels vector must supply a single specific index of the true class for each row of logits.

tf.nn.sparse_softmax_cross_entropy_with_logits(
    _sentinel=None,
    labels=None,
    logits=None,
    name=None
)

Warning:

This op expects unscaled logits: for efficiency it applies softmax to the logits internally.

Do not apply softmax beforehand, as that produces incorrect results.

logits dtype: float16, float32, or float64.

labels dtype: int32 or int64.
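A NumPy sketch of what the op computes (assumed values): labels are class indices, and the logits stay unscaled since the softmax happens inside.

```python
import numpy as np

labels = np.array([2, 0])                 # one class index per example
logits = np.array([[0.5, 1.0, 3.0],
                   [2.0, 0.1, 0.1]])

# Stable log-softmax over the class axis.
shifted = logits - logits.max(axis=-1, keepdims=True)
log_softmax = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

# Pick out the log-probability of each example's true class.
per_example = -log_softmax[np.arange(len(labels)), labels]
loss = per_example.mean()
print(loss)  # ~0.2292
```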

Last modification: October 25th, 2018 at 02:07 pm