Batch Normalization in TensorFlow

tf.nn.moments
tf.nn.moments(
    x,
    axes,
    shift=None,
    name=None,
    keep_dims=False
)

Args:
x: A Tensor.
axes: Array of ints. Axes along which to compute mean and variance.
shift: Not used in the current implementation.
name: Name used to scope the operations that compute the moments.
keep_dims: If True, produce moments with the same dimensionality as the input (reduced axes are kept with size 1).

Returns:
Two Tensor objects: mean and variance.
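To make these semantics concrete, here is a minimal NumPy sketch of what tf.nn.moments computes. It is a re-statement of the math, not TensorFlow's implementation, and the helper name moments_np is hypothetical. The mean is taken over the listed axes, and the variance is the mean squared deviation from that mean (dividing by N, not N - 1). With keep_dims=True the reduced axes stay in the result with size 1, so mean and variance broadcast back against x.

```python
import numpy as np

def moments_np(x, axes, keep_dims=False):
    """Hypothetical NumPy equivalent of tf.nn.moments (illustration only)."""
    axes = tuple(axes)
    # Mean over the requested axes, kept as size-1 axes so it broadcasts against x.
    mean = np.mean(x, axis=axes, keepdims=True)
    # Population variance: average squared deviation (divides by N, not N - 1).
    variance = np.mean(np.square(x - mean), axis=axes, keepdims=True)
    if not keep_dims:
        # Drop the reduced axes, like keep_dims=False in tf.nn.moments.
        mean = np.squeeze(mean, axis=axes)
        variance = np.squeeze(variance, axis=axes)
    return mean, variance

# A batch of 2 examples with 3 features; axes=[0] gives per-feature statistics.
x = np.array([[1.0, 2.0, 3.0],
              [3.0, 6.0, 9.0]])
mean, variance = moments_np(x, axes=[0])
print(mean)      # [2. 4. 6.]
print(variance)  # [1. 4. 9.]
```

Passing every axis (here axes=[0, 1]) collapses the whole tensor to a single scalar mean and variance, which is usually not what batch normalization wants for feature maps.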

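import tensorflow as tf  # TensorFlow 1.x API (tf.variable_scope, tf.get_variable)

def bn(x, name="bn", epsilon=1e-5):
    # Normalize over every axis except the last (channel) axis,
    # e.g. [batch, height, width] for NHWC feature maps.
    axes = list(range(len(x.get_shape()) - 1))
    params_shape = x.get_shape()[-1:]  # one parameter per channel
    # variable_scope keeps "beta"/"gamma" unique when bn() is called more than once.
    with tf.variable_scope(name):
        beta = tf.get_variable("beta", shape=params_shape, initializer=tf.constant_initializer(0.0))
        gamma = tf.get_variable("gamma", shape=params_shape, initializer=tf.constant_initializer(1.0))
        # Per-channel batch mean and variance (batch statistics only; no moving averages).
        x_mean, x_variance = tf.nn.moments(x, axes)
        # y = gamma * (x - mean) / sqrt(variance + epsilon) + beta
        y = tf.nn.batch_normalization(x, x_mean, x_variance, beta, gamma, epsilon)
    return y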
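tf.nn.batch_normalization does not compute any statistics itself; given a mean, a variance, an offset (beta) and a scale (gamma), it applies y = gamma * (x - mean) / sqrt(variance + epsilon) + beta. The NumPy sketch below restates that formula together with per-channel moments taken over every axis except the last, as is usual for NHWC feature maps. The helper name batch_norm_np, the epsilon of 1e-5, and the random test data are assumptions for illustration, not TensorFlow code. With gamma = 1 and beta = 0, every output channel should end up with mean close to 0 and variance close to 1.

```python
import numpy as np

def batch_norm_np(x, beta, gamma, epsilon=1e-5):
    """Hypothetical NumPy sketch of batch normalization for channels-last input."""
    # Statistics over every axis except the last (channel) axis.
    axes = tuple(range(x.ndim - 1))
    mean = x.mean(axis=axes)      # shape: (channels,)
    variance = x.var(axis=axes)   # population variance (ddof=0), like tf.nn.moments
    # The formula applied by tf.nn.batch_normalization (offset=beta, scale=gamma).
    x_hat = (x - mean) / np.sqrt(variance + epsilon)
    return gamma * x_hat + beta

rng = np.random.default_rng(0)
# Fake NHWC activations: batch=8, height=4, width=4, channels=3.
x = rng.normal(loc=5.0, scale=2.0, size=(8, 4, 4, 3))
channels = x.shape[-1]
y = batch_norm_np(x, beta=np.zeros(channels), gamma=np.ones(channels))
print(y.mean(axis=(0, 1, 2)))  # close to 0 for every channel
print(y.var(axis=(0, 1, 2)))   # close to 1 for every channel
```

Because gamma and beta are applied after normalization, the network can learn to undo it (for example gamma = sqrt(variance + epsilon), beta = mean) if that turns out to be useful.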

Last modification: October 9, 2023