CNN中的FLOPs理解
FLOPS
注意全大写,是floating point operations per second的缩写,意指每秒浮点运算次数,理解为计算速度。是一个衡量硬件性能的指标。
FLOPs
注意s小写,是floating point operations的缩写(s表复数),意指浮点运算数,理解为计算量。可以用来衡量算法/模型的复杂度。本问题针对模型,应指的是FLOPs。以下答案不考虑activation function的运算。
卷积层FLOPs
(K × K × Ci + K × K × Ci - 1 ) × H × W × Co = ( 2 ( K × K × Ci) - 1 ) × H × W × Co
Ci=input channel, k=kernel size, HW=output feature map size, Co=output channel. 2是因为一个MAC算2个operations。不考虑bias时有-1,有bias时没有-1。上面针对一个input feature map,没考虑batch size。理解上面这个公式分两步,括号内是第一步,计算出output feature map的一个pixel,然后再乘以H W Co,拓展到整个output feature map。括号内的部分又可以分为两步,第一项是乘法运算数,第二项是加法运算数,因为n个数相加,要加n-1次,所以不考虑bias,有一个-1,如果考虑bias,刚好中和掉
FC层
(I + I - 1)× O = (2 × I - 1 )× O
I=input neuron numbers, O=output neuron numbers. -1,有bias时没有-1。分析同理,括号内是一个输出神经元的计算量,拓展到O了输出神经元。
推荐两个神器(pytorch)
torchstat 、 torchsummaryX可以用来计算pytorch构建的网络的参数,空间大小,MAdd,FLOPs等指标,简单好用
参考文献
https://www.zhihu.com/question/65305385/answer/451060549
torchstat:https://github.com/Swall0w/torchstat
torchsummaryX:https://github.com/nmhkahn/torchsummaryX