Caffe源码解读之Blob(二)

机器学习 2017-01-14

  阅读Caffe的源码不必全部看,这样反而达不到目的,理解关键部分即可。现在我们来看Blob结构部分,按头文件和核心文件来阅读,即 /include和 /src部分,推荐使用Atom来看。   在/include/blob.hpp中定义了了Blob的数据结构。

 explicit Blob(const int num, const int channels, const int height,
      const int width);

  num用于存储数据或权值(data)和权值增量(diff),其它三个分别是图片的通道数,长和宽。

 protected:
  shared_ptr<SyncedMemory> data_;
  shared_ptr<SyncedMemory> diff_;
  shared_ptr<SyncedMemory> shape_data_;
  vector<int> shape_;
  int count_;
  int capacity_;

  从C++角度,Blob在blob.hpp中是一个模板类。protected 的成员变量有:data , diff , shape , count , capacity ,其中data 和 diff_ 是共享SyncedMemory 类(在syncedmem的源码中定义)的智能指针,shape是int型的vector,count 和capacity_ 是整型变量。   其成员函数主要有:Reshape 、ReshapeLike、SharedData、 Updata 等。

#include "caffe/common.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/syncedmem.hpp"

  blob.hpp 包含了caffe.pb.h ,说明caffe protobuf 会向blob传递参数。   caffe.pb.h是google的protocol buffer根据caffe.proto自动生成的,可以到src/caffe/proto/caffe.proto里看下caffe里面用到的各个数据的定义,比如BlobProto,Datum,NetParameter等。使用这个protocol buffer看起来确实方便,一方面可以用文本文件定义结构化的数据类型,另一方面可以生成查询效率更高、占空间更小的二进制文件。   common.hpp主要singleton化Caffe类,并封装了boost和CUDA随机数生成的函数,提供了统一的接口。   syncedmem.hpp主要是分配内存和释放内存的。而class SyncedMemory定义了内存分配管理和CPU与GPU之间同步的函数。   在/src/blob.cpp中

#include "caffe/util/math_functions.hpp"

  math_functions.hpp里封装了很多cblas矩阵运算函数:caffe_cpu_gemm(),caffe_cpu_gemv(),caffe_axpy()等等。具体实现则在math_functions.cpp里面。

  再具体看看刚才提到的Blob的成员变量,其实不多。首先是data_指针,指针类型是shared_ptr,属于boost库的一个智能指针,这一部分主要用来申请内存存储data,data主要是正向传播的时候用的。同理,diff_主要用来存储偏差,update data,shape_data和shape_都是存储Blob的形状。count表示Blob中的元素个数,也就是个数通道数高度*宽度,capacity表示当前的元素个数,因为Blob可能会reshape。

 void Blob<Dtype>::Reshape(const int num, const int channels, const int height,
    const int width) {
  vector<int> shape(4);
  shape[0] = num;
  shape[1] = channels;
  shape[2] = height;
  shape[3] = width;
  Reshape(shape);
}

  下面来说说blob中的函数,可以根据功能来分类。   1.blob中的构造函数开辟一个内存空间来存储数据,Reshape函数在Layer中的reshape或者forward操作中来adjust dimension。同时在改变Blob大小时,内存将会被重新分配如果内存大小不够了,并且额外的内存将不会被释放。对input的blob进行reshape,如果立马调用Net::Backward是会出错的,因为reshape之后,要么Net::forward或者Net::Reshape就会被调用来将新的input shape 传播到高层。   2.同时重载很多个count()函数,主要还是为了统计Blob的容量(volume),或者是某一片(slice),从某个axis到具体某个axis的shape乘积(如 “ inline int count(int start_axis, int end_axis)”)。   3.data_数据操作函数和反向传播导数diff_操作函数有:

inline Dtype data_at(const int n, const int c, const int h, const int w)
inline Dtype diff_at(const int n, const int c, const int h, const int w)
inline Dtype data_at(const vector<int>& index)
inline Dtype diff_at(const vector<int>& index)
inline const shared_ptr<SyncedMemory>& data()
inline const shared_ptr<SyncedMemory>& diff()

  可以在blob.hpp中找到。这一部分函数主要通过给定的位置访问数据,根据位置计算与数据起始的偏差offset,在通过cpu_data*指针获得地址。

  4.将数据序列化,存储到BlobProto,这里说到Proto是谷歌的一个数据序列化的存储格式,可以实现语言、平台无关、可扩展的序列化结构数据格式。

void FromProto(const BlobProto& proto, bool reshape = true);
void ToProto(BlobProto* proto, bool write_diff = false) const;

  5.update函数,该函数用于参数blob的更新(weight,bias 等减去对应的导数)。

  6.其它和运算相关的函数也一一列出

Dtype asum_data() const;//计算data的L1范数(所有元素绝对值之和)
Dtype asum_diff() const;//计算diff的L1范数
Dtype sumsq_data() const;//计算data的L2范数(所有元素平方和)
Dtype sumsq_diff() const;//计算diff的L2范数
void scale_data(Dtype scale_factor);//将data部分乘以一个因子
void scale_diff(Dtype scale_factor);//将diff部分乘一个因子

  关于Blob部分也就这些内容了。


本文由 Tony 创作,采用 知识共享署名 3.0,可自由转载、引用,但需署名作者且注明文章出处。

如果对您有用,您的支持将鼓励我继续创作!