Batch Normalization in TensorFlow
tf.nn.moments
tf.nn.moments(
    x,
    axes,
    shift=None,
    name=None,
    keep_dims=False
)
Args:
x: A Tensor.
axes: Array of ints. Axes along which to compute mean and variance.
shift: Not used in the current implementation.
name: Name used to scope the operations that compute the moments.
keep_dims: If True, produce moments with the same dimensionality as the input (the argument is named keepdims in TensorFlow 2.x).
Returns:
Two Tensor objects: mean and variance.
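To make the Args and Returns above concrete, here is a minimal NumPy sketch of the statistics tf.nn.moments computes. The helper name moments is hypothetical and this is not TensorFlow's code; it assumes a float array and reproduces the documented behaviour: the mean and the population variance (divided by N, no Bessel correction) over the given axes, with the reduced axes dropped unless keep_dims is True.

```python
import numpy as np

def moments(x, axes, keep_dims=False):
    """Hypothetical NumPy re-implementation of the statistics tf.nn.moments returns."""
    axes = tuple(axes)
    mean = np.mean(x, axis=axes, keepdims=True)
    # Population variance: mean squared deviation from the mean, divided by N.
    variance = np.mean(np.square(x - mean), axis=axes, keepdims=True)
    if not keep_dims:
        # Drop the reduced axes, matching the default keep_dims=False.
        mean = np.squeeze(mean, axis=axes)
        variance = np.squeeze(variance, axis=axes)
    return mean, variance

x = np.array([[1.0, 2.0],
              [3.0, 6.0]])  # 2 examples (axis 0) x 2 features (axis 1)
mean, variance = moments(x, axes=[0])  # per-feature statistics over the batch
print(mean)      # [2. 4.]
print(variance)  # [1. 4.]

mean_k, variance_k = moments(x, axes=[0], keep_dims=True)
print(mean_k.shape)  # (1, 2)
```

Passing axes=[0] for a [batch, features] input gives one mean and variance per feature, which is what batch normalization of a dense layer needs.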
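import tensorflow as tf  # TensorFlow 1.x API (tf.get_variable / tf.variable_scope)

def bn(x, name="bn", epsilon=1e-10):
    # Batch statistics over every axis except the last (channel) axis:
    # [0] for a [batch, features] input, [0, 1, 2] for an NHWC feature map.
    axes = list(range(len(x.get_shape()) - 1))
    channels = x.get_shape().as_list()[-1]
    # Scope the variables by name so bn() can be called for several layers.
    with tf.variable_scope(name):
        beta = tf.get_variable("beta", shape=[channels], initializer=tf.constant_initializer(0.0))
        gamma = tf.get_variable("gamma", shape=[channels], initializer=tf.constant_initializer(1.0))
        x_mean, x_variance = tf.nn.moments(x, axes)
        # y = gamma * (x - mean) / sqrt(variance + epsilon) + beta
        y = tf.nn.batch_normalization(x, x_mean, x_variance, beta, gamma, epsilon)
    return y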
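The normalization step in bn() can be checked numerically without TensorFlow. Below is a NumPy sketch of the formula tf.nn.batch_normalization applies, scale * (x - mean) / sqrt(variance + variance_epsilon) + offset, using a hypothetical helper of the same name. It assumes an NHWC feature map and normalizes over axes (0, 1, 2), so each channel gets its own mean and variance; beta and gamma start at 0 and 1, as in the initializers above.

```python
import numpy as np

def batch_normalization(x, mean, variance, offset, scale, variance_epsilon):
    # Hypothetical NumPy version of the formula used by tf.nn.batch_normalization.
    return scale * (x - mean) / np.sqrt(variance + variance_epsilon) + offset

rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, scale=3.0, size=(8, 4, 4, 3))  # NHWC: batch, height, width, channels
axes = (0, 1, 2)                 # every axis except the channel axis
mean = x.mean(axis=axes)         # shape (3,): one mean per channel
variance = x.var(axis=axes)      # shape (3,): population variance per channel
beta = np.zeros(3)               # offset, initialized to 0
gamma = np.ones(3)               # scale, initialized to 1
y = batch_normalization(x, mean, variance, beta, gamma, 1e-10)

print(y.mean(axis=axes))  # close to 0 for each channel
print(y.var(axis=axes))   # close to 1 for each channel
```

With gamma = 1 and beta = 0 the output has roughly zero mean and unit variance per channel; training then learns gamma and beta to rescale and shift it.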