在上文对Command Line Interfaces进行了简单的介绍之后,本文将对caffe的Solver相关的代码进行分析。

本文将主要分为四部分的内容:

  • Solver的初始化(Register宏和构造函数)
  • SIGINT和SIGHUP信号的处理
  • Solver::Solve()具体实现
  • SGDSolver::ApplyUpdate具体实现

Solver的初始化(Register宏和构造函数)

shared_ptr<caffe::Solver<float> >
    solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));

caffe.cpp中的train函数中通过上面的代码定义了一个指向Solver的shared_ptr。其中主要是通过调用SolverRegistry这个类的静态成员函数CreateSolver得到一个指向Solver的指针来构造shared_ptr类型的solver。而且由于C++多态的特性,尽管solver是一个指向基类Solver类型的指针,通过solver这个智能指针来调用各个成员函数会调用到各个子类(SGDSolver等)的函数。具体的过程如下面的流程图所示:

Create solver

下面我们就来具体看一下SolverRegistry这个类的代码,以便理解是如何通过同一个函数得到不同类型的Solver:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
class SolverRegistry {
 public:
  typedef Solver<Dtype>* (*Creator)(const SolverParameter&);
  typedef std::map<string, Creator> CreatorRegistry;
  static CreatorRegistry& Registry() {
    static CreatorRegistry* g_registry_ = new CreatorRegistry();
    return *g_registry_;
  }
  static void AddCreator(const string& type, Creator creator) {
    CreatorRegistry& registry = Registry();
    CHECK_EQ(registry.count(type), 0)
        << "Solver type " << type << " already registered.";
    registry[type] = creator;
  }
  static Solver<Dtype>* CreateSolver(const SolverParameter& param) {
    const string& type = param.type();
    CreatorRegistry& registry = Registry();
    CHECK_EQ(registry.count(type), 1) << "Unknown solver type: " << type
        << " (known types: " << SolverTypeListString() << ")";
    return registry[type](param);
  }
  static vector<string> SolverTypeList() {
    CreatorRegistry& registry = Registry();
    vector<string> solver_types;
    for (typename CreatorRegistry::iterator iter = registry.begin();
         iter != registry.end(); ++iter) {
      solver_types.push_back(iter->first);
    }
    return solver_types;
  }
 private:
  SolverRegistry() {}
  static string SolverTypeListString() {
    vector<string> solver_types = SolverTypeList();
    string solver_types_str;
    for (vector<string>::iterator iter = solver_types.begin();
         iter != solver_types.end(); ++iter) {
      if (iter != solver_types.begin()) {
        solver_types_str += ", ";
      }
      solver_types_str += *iter;
    }
    return solver_types_str;
  }
};

首先需要注意的是这个类的构造函数是private的,也就是用我们没有办法去构造一个这个类型的变量,这个类也没有数据成员,所有的成员函数也都是static的,可以直接调用。

我们首先从 CreateSolver 函数(第15行)入手,这个函数先定义了string类型的变量type,表示Solver的类型(‘SGD’/’Nestrov’等),然后定义了一个key类型为string,value类型为Creator的map:registry,其中Creator是一个函数指针类型,指向的函数的参数为SolverParameter类型,返回类型为Solver*(见第2行和第3行)。如果是一个已经register过的Solver类型,那么registry.count(type)应该为1,然后通过registry这个map返回了我们需要类型的Solver的creator,并调用这个creator函数,将creator返回的Solver*返回。

上面的代码中,Registry这个函数(第5行)中定义了一个static的变量g_registry,这个变量是一个指向CreatorRegistry这个map类型的指针,然后直接返回,因为这个变量是static的,所以即使多次调用这个函数,也只会定义一个g_registry,而且在其他地方修改这个map里的内容,是存储在这个map中的。事实上各个Solver的register的过程正是往g_registry指向的那个map里添加以Solver的type为key,对应的Creator函数指针为value的内容。Register的过程如流程图所示:

Register Solver

下面我们具体来看一下Solver的register的过程:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
template <typename Dtype>
class SolverRegisterer {
 public:
  SolverRegisterer(const string& type,
      Solver<Dtype>* (*creator)(const SolverParameter&)) {
    // LOG(INFO) << "Registering solver type: " << type;
    SolverRegistry<Dtype>::AddCreator(type, creator);
  }
};
#define REGISTER_SOLVER_CREATOR(type, creator)                                 \
  static SolverRegisterer<float> g_creator_f_##type(#type, creator<float>);    \
  static SolverRegisterer<double> g_creator_d_##type(#type, creator<double>)   \

#define REGISTER_SOLVER_CLASS(type)                                            \
  template <typename Dtype>                                                    \
  Solver<Dtype>* Creator_##type##Solver(                                       \
      const SolverParameter& param)                                            \
  {                                                                            \
    return new type##Solver<Dtype>(param);                                     \
  }                                                                            \
  REGISTER_SOLVER_CREATOR(type, Creator_##type##Solver)
}
// register SGD Solver
REGISTER_SOLVER_CLASS(SGD);

在sgd_solver.cpp(SGD Solver对应的cpp文件)末尾有上面第24行的代码,使用了REGISTER_SOLVER_CLASS这个宏,这个宏会定义一个名为Creator_SGDSolver的函数,这个函数即为Creator类型的指针指向的函数,在这个函数中调用了SGDSolver的构造函数,并将构造的这个变量得到的指针返回,这也就是Creator类型函数的作用:构造一个对应类型的Solver对象,将其指针返回。然后在这个宏里又调用了REGISTER_SOLVER_CREATOR这个宏,这里分别定义了SolverRegisterer这个模板类的float和double类型的static变量,这会去调用各自的构造函数,而在SolverRegisterer的构造函数中调用了之前提到的SolverRegistry类的AddCreator函数,这个函数就是将刚才定义的Creator_SGDSolver这个函数的指针存到g_registry指向的map里面。类似地,所有的Solver对应的cpp文件的末尾都调用了这个宏来完成注册,在所有的Solver都注册之后,我们就可以通过之前描述的方式,通过g_registry得到对应的Creator函数的指针,并通过调用这个Creator函数来构造对应的Solver。Register和Create对应的流程图如下所示:

SIGINT和SIGHUP信号的处理

Caffe在train或者test的过程中都有可能会遇到系统信号(用户按下ctrl+c或者关掉了控制的terminal),我们可以通过对sigint_effect和sighup_effect来设置遇到系统信号的时候希望进行的处理方式:

caffe train –solver=/path/to/solver.prototxt –sigint_effect=EFFECT –sighup_effect=EFFECT

在caffe.cpp中定义了一个GetRequesedAction函数来将设置的string类型的标志转变为枚举类型的变量:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
caffe::SolverAction::Enum GetRequestedAction(
    const std::string& flag_value) {
  if (flag_value == "stop") {
    return caffe::SolverAction::STOP;
  }
  if (flag_value == "snapshot") {
    return caffe::SolverAction::SNAPSHOT;
  }
  if (flag_value == "none") {
    return caffe::SolverAction::NONE;
  }
  LOG(FATAL) << "Invalid signal effect \""<< flag_value << "\" was specified";
}
// SolverAction::Enum的定义
namespace SolverAction {
  enum Enum {
    NONE = 0,  // Take no special action.
    STOP = 1,  // Stop training. snapshot_after_train controls whether a
               // snapshot is created.
    SNAPSHOT = 2  // Take a snapshot, and keep training.
  };
}

其中SolverAction::Enum的定义在solver.hpp中,这是一个定义为枚举类型的数据类型,只有三个可能的值,分别对应了三种处理系统信号的方式:NONE(忽略信号什么都不做)/STOP(停止训练)/SNAPSHOT(保存当前的训练状态,继续训练)。在caffe.cpp中的train函数里Solver设置如何处理系统信号的代码为:

1
2
3
4
5
caffe::SignalHandler signal_handler(
      GetRequestedAction(FLAGS_sigint_effect),
      GetRequestedAction(FLAGS_sighup_effect));

solver->SetActionFunction(signal_handler.GetActionFunction());

FLAGS_sigint_effect和FLAGS_sighup_effect是通过gflags定义和解析的两个Command Line Interface的输入参数,分别对应遇到sigint和sighup信号的处理方式,如果用户不设定(大部分时候我自己就没设定),sigint的默认值为”stop”,sighup的默认值为”snapshot”。GetRequestedAction函数会将string类型的FLAGS_xx转为SolverAction::Enum类型,并用来定义一个SignalHandler类型的对象signal_handler。我们可以看到这部分代码都依赖于SignalHandler这个类的接口,我们先来看看这个类都做了些什么:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
// header file
class SignalHandler {
 public:
  // Contructor. Specify what action to take when a signal is received.
  SignalHandler(SolverAction::Enum SIGINT_action,
                SolverAction::Enum SIGHUP_action);
  ~SignalHandler();
  ActionCallback GetActionFunction();
 private:
  SolverAction::Enum CheckForSignals() const;
  SolverAction::Enum SIGINT_action_;
  SolverAction::Enum SIGHUP_action_;
};
// source file
SignalHandler::SignalHandler(SolverAction::Enum SIGINT_action,
                             SolverAction::Enum SIGHUP_action):
  SIGINT_action_(SIGINT_action),
  SIGHUP_action_(SIGHUP_action) {
  HookupHandler();
}
void HookupHandler() {
  if (already_hooked_up) {
    LOG(FATAL) << "Tried to hookup signal handlers more than once.";
  }
  already_hooked_up = true;
  struct sigaction sa;
  sa.sa_handler = &handle_signal;
  // ...
}
static volatile sig_atomic_t got_sigint = false;
static volatile sig_atomic_t got_sighup = false;
void handle_signal(int signal) {
  switch (signal) {
  case SIGHUP:
    got_sighup = true;
    break;
  case SIGINT:
    got_sigint = true;
    break;
  }
}
ActionCallback SignalHandler::GetActionFunction() {
  return boost::bind(&SignalHandler::CheckForSignals, this);
}
SolverAction::Enum SignalHandler::CheckForSignals() const {
  if (GotSIGHUP()) {
    return SIGHUP_action_;
  }
  if (GotSIGINT()) {
    return SIGINT_action_;
  }
  return SolverAction::NONE;
}
bool GotSIGINT() {
  bool result = got_sigint;
  got_sigint = false;
  return result;
}
bool GotSIGHUP() {
  bool result = got_sighup;
  got_sighup = false;
  return result;
}
// ActionCallback的含义
typedef boost::function<SolverAction::Enum()> ActionCallback;

SignalHandler这个类有两个数据成员,都是SolverAction::Enum类型的,分别对应sigint和sighup信号,在构造函数中,用解析FLAGS_xx得到的结果分别给两个成员赋值,然后调用了HookupHandler函数,这个函数的主要作用是定义了一个sigaction类型(应该是系统级别的代码)的对象sa,然后通过sa.sa_handler = &handle_signal来设置,当有遇到系统信号时,调用handle_signal函数来处理,而我们可以看到这个函数的处理很简单,就是判断一下当前的信号是什么类型,如果是sigint就将全局的static变量got_sigint变为true,sighup的处理类似。

在根据用户设置(或者默认值)的参数定义了signal_handler之后,solver通过SetActionFunction来设置了如何处理系统信号。这个函数的输入为signal_handler的GetActionFunction的返回值,根据上面的代码我们可以看到,GetActionFunction会返回signal_handler这个对象的CheckForSignals函数的地址(boost::bind的具体使用请参考boost官方文档)。而在Solver的SetActionFunction函数中只是简单的把Solver的一个成员action_request_function_赋值为输入参数的值,以当前的例子来说就是,solver对象的action_request_function_指向了signal_handler对象的CheckForSignals函数的地址。其中的ActionCallback是一个函数指针类型,指向了参数为空,返回值为SolverAction::Enum类型的函数(boost::function具体用法参考官方文档)。

总结起来,我们通过定义一个SignalHandler类型的对象,告知系统在遇到系统信号的时候回调handle_signal函数来改变全局变量got_sigint和got_sighup的值,然后通过Solver的接口设置了其遇到系统函数将调用signal_handler的Check函数,这个函数实际上就是去判断当前是否遇到了系统信号,如果遇到某个类型的信号,就返回我们之前设置的处理方式(SolverAction::Enum类型)。剩余的具体处理再交给Solver的其它函数,后面会具体分析。

Solver::Solve()具体实现

Solve函数实现了具体的网络的优化过程,下面我们来具体分析一下这部分的代码,分析见注释:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
void Solver<Dtype>::Solve(const char* resume_file) {
// 检查当前是否是root_solver(多GPU模式下,只有root_solver才运行这一部分的代码)
  CHECK(Caffe::root_solver());
// 然后输出learning policy(更新学习率的策略)
  LOG(INFO) << "Solving " << net_->name();
  LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy();
// requested_early_exit_一开始被赋值为false,也就是现在没有要求在优化结束前退出
  requested_early_exit_ = false;
// 判断resume_file这个指针是否NULL,如果不是则需要从resume_file存储的路径里读取之前训练的状态
  if (resume_file) {
    LOG(INFO) << "Restoring previous solver status from " << resume_file;
    Restore(resume_file);
  }
// 然后调用了'Step'函数,这个函数执行了实际的逐步的迭代过程
  Step(param_.max_iter() - iter_);
// 迭代结束或者遇到系统信号提前结束后,判断是否需要在训练结束之后snapshot
// 这个可以在solver.prototxt里设置
  if (param_.snapshot_after_train()
      && (!param_.snapshot() || iter_ % param_.snapshot() != 0)) {
    Snapshot();
  }
// 如果在Step函数的迭代过程中遇到了系统信号,且我们的处理方式设置为STOP,
// 那么requested_early_exit_会被修改为true,迭代提前结束,输出相关信息
  if (requested_early_exit_) {
    LOG(INFO) << "Optimization stopped early.";
    return;
  }
// 判断是否需要输出最后的loss
  if (param_.display() && iter_ % param_.display() == 0) {
    Dtype loss;
    net_->ForwardPrefilled(&loss);
    LOG(INFO) << "Iteration " << iter_ << ", loss = " << loss;
  }
// 判断是否需要最后Test
  if (param_.test_interval() && iter_ % param_.test_interval() == 0) {
    TestAll();
  }
  LOG(INFO) << "Optimization Done.";
}

下面继续分析具体的迭代过程发生的Step函数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
template <typename Dtype>
void Solver<Dtype>::Step(int iters) {
  vector<Blob<Dtype>*> bottom_vec;
// 设置开始的迭代次数(如果是从之前的snapshot恢复的,那iter_等于snapshot时的迭代次数)和结束的迭代次数
  const int start_iter = iter_;
  const int stop_iter = iter_ + iters;
// 输出的loss为前average_loss次loss的平均值,在solver.prototxt里设置,默认为1,
// losses存储之前的average_loss个loss,smoothed_loss为最后要输出的均值
  int average_loss = this->param_.average_loss();
  vector<Dtype> losses;
  Dtype smoothed_loss = 0;
// 迭代
  while (iter_ < stop_iter) {
  // 清空上一次所有参数的梯度
    net_->ClearParamDiffs();
// 判断是否需要测试
    if (param_.test_interval() && iter_ % param_.test_interval() == 0
        && (iter_ > 0 || param_.test_initialization())
        && Caffe::root_solver()) {
      TestAll();
    // 判断是否需要提前结束迭代
      if (requested_early_exit_) {
        break;
      }
    }
    for (int i = 0; i < callbacks_.size(); ++i) {
      callbacks_[i]->on_start();
    }
    // 判断当前迭代次数是否需要显示loss等信息
    const bool display = param_.display() && iter_ % param_.display() == 0;
    net_->set_debug_info(display && param_.debug_info());
    Dtype loss = 0;
    // iter_size也是在solver.prototxt里设置,实际上的batch_size=iter_size*网络定义里的batch_size,
    // 因此每一次迭代的loss是iter_size次迭代的和,再除以iter_size,这个loss是通过调用Net::ForwardBackward函数得到的
    // 这个设置我的理解是在GPU的显存不够的时候使用,比如我本来想把batch_size设置为128,但是会out_of_memory,
    // 借助这个方法,可以设置batch_size=32,iter_size=4,那实际上每次迭代还是处理了128个数据
    for (int i = 0; i < param_.iter_size(); ++i) {
      loss += net_->ForwardBackward(bottom_vec);
    }
    loss /= param_.iter_size();
    // 计算要输出的smoothed_loss,如果losses里还没有存够average_loss个loss则将当前的loss插入,如果已经存够了,则将之前的替换掉
    if (losses.size() < average_loss) {
      losses.push_back(loss);
      int size = losses.size();
      smoothed_loss = (smoothed_loss * (size - 1) + loss) / size;
    } else {
      int idx = (iter_ - start_iter) % average_loss;
      smoothed_loss += (loss - losses[idx]) / average_loss;
      losses[idx] = loss;
    }
    // 输出当前迭代的信息
    if (display) {
      LOG_IF(INFO, Caffe::root_solver()) << "Iteration " << iter_
          << ", loss = " << smoothed_loss;
      const vector<Blob<Dtype>*>& result = net_->output_blobs();
      int score_index = 0;
      for (int j = 0; j < result.size(); ++j) {
        const Dtype* result_vec = result[j]->cpu_data();
        const string& output_name =
            net_->blob_names()[net_->output_blob_indices()[j]];
        const Dtype loss_weight =
            net_->blob_loss_weights()[net_->output_blob_indices()[j]];
        for (int k = 0; k < result[j]->count(); ++k) {
          ostringstream loss_msg_stream;
          if (loss_weight) {
            loss_msg_stream << " (* " << loss_weight
                            << " = " << loss_weight * result_vec[k] << " loss)";
          }
          LOG_IF(INFO, Caffe::root_solver()) << "    Train net output #"
              << score_index++ << ": " << output_name << " = "
              << result_vec[k] << loss_msg_stream.str();
        }
      }
    }
    for (int i = 0; i < callbacks_.size(); ++i) {
      callbacks_[i]->on_gradients_ready();
    }
    // 执行梯度的更新,这个函数在基类Solver中没有实现,会调用每个子类自己的实现,后面具体分析SGDSolver的实现
    ApplyUpdate();
    // 迭代次数加1
    ++iter_;
    // 调用GetRequestedAction,实际是通过action_request_function_函数指针调用之前设置好(通过SetRequestedAction)的
    // signal_handler的CheckForSignals函数,这个函数的作用是
    // 会根据之前是否遇到系统信号以及信号的类型和我们设置(或者默认)的方式返回处理的方式
    SolverAction::Enum request = GetRequestedAction();
    // 判断当前迭代是否需要snapshot,如果request等于SNAPSHOT则也需要
    if ((param_.snapshot()
         && iter_ % param_.snapshot() == 0
         && Caffe::root_solver()) ||
         (request == SolverAction::SNAPSHOT)) {
      Snapshot();
    }
    // 如果request为STOP则修改requested_early_exit_为true,之后就会提前结束迭代
    if (SolverAction::STOP == request) {
      requested_early_exit_ = true;
      break;
    }
  }
}

SGDSolver::ApplyUpdate具体实现

每一组网络中的参数的更新都是在不同类型的Solver自己实现的ApplyUpdate函数中完成的,下面我们就以最常用的SGD为例子来分析这个函数具体的功能:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
template <typename Dtype>
void SGDSolver<Dtype>::ApplyUpdate() {
  CHECK(Caffe::root_solver());
  // GetLearningRate根据设置的lr_policy来计算当前迭代的learning rate的值
  Dtype rate = GetLearningRate();
  // 判断是否需要输出当前的learning rate
  if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
    LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;
  }
  // 避免梯度爆炸,如果梯度的二范数超过了某个数值则进行scale操作,将梯度减小
  ClipGradients();
  // 对所有可更新的网络参数进行操作
  for (int param_id = 0; param_id < this->net_->learnable_params().size();
       ++param_id) {
    // 将第param_id个参数的梯度除以iter_size,这一步的作用是保证实际的batch_size=iter_size*设置的batch_size
    Normalize(param_id);
    // 将正则化部分的梯度降入到每个参数的梯度中 
    Regularize(param_id);
    // 计算SGD算法的梯度(momentum等)
    ComputeUpdateValue(param_id, rate);
  }
  // 调用Net::Update更新所有的参数
  this->net_->Update();
}

下面我们继续具体分析一下Normalize/Regularize/ComputeUpdateValue的实现,我们均以CPU的代码为例子,GPU部分的处理原理是一样的:

Normalize

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
template <typename Dtype>
void SGDSolver<Dtype>::Normalize(int param_id) {
  // 如果iter_size的值为1,则不需要任何处理直接return
  if (this->param_.iter_size() == 1) { return; }
  // 通过net_返回所有可以学习的参数,是一个vector<shared_ptr<Blob<Dtype> > >
  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
  // 要乘以的系数等于1/iter_size
  const Dtype accum_normalization = Dtype(1.) / this->param_.iter_size();
  switch (Caffe::mode()) {
  case Caffe::CPU: {
    // caffe_scal在/CAFFE_ROOT/src/caffe/util/math_functions.cpp中
    // 是blas的scale函数的一个封装,第一个参数是数据的个数,第二个参数是乘以的系数,
    // 第三个参数是数据的指针
    caffe_scal(net_params[param_id]->count(), accum_normalization,
        net_params[param_id]->mutable_cpu_diff());
    break;
  }
  case Caffe::GPU: { 
    // GPU代码略
  }
}

Regularize

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
template <typename Dtype>
void SGDSolver<Dtype>::Regularize(int param_id) {
  // 获取所有可以学习的参数的vector
  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
  // 获取所有的参数对应的weight_decay的vector
  const vector<float>& net_params_weight_decay =
      this->net_->params_weight_decay();
  // 模型整体的weight_decay数值
  Dtype weight_decay = this->param_.weight_decay();
  // 获取正则化的类型:L1 或 L2
  string regularization_type = this->param_.regularization_type();
  // 实际的weight_decay等于整体模型的数值乘以具体每个参数的数值
  Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
  switch (Caffe::mode()) {
  case Caffe::CPU: {
    // 如果weight_decay不为0,则计算
    if (local_decay) {
      if (regularization_type == "L2") {
        // L2的梯度为diff_ = weight_decay*data_ + diff_
        // caffe_axpy的功能是 y = a*x + y
        // 第一个参数是数据的个数,第二个是上式的a,第三个是x的指针,第四个是y的指针
        caffe_axpy(net_params[param_id]->count(),
            local_decay,
            net_params[param_id]->cpu_data(),
            net_params[param_id]->mutable_cpu_diff());
      } else if (regularization_type == "L1") {
        // L1的梯度为diff_ = diff_ + sign(data_)
        // temp_ = sign(data_)
        caffe_cpu_sign(net_params[param_id]->count(),
            net_params[param_id]->cpu_data(),
            temp_[param_id]->mutable_cpu_data());
        // 将temp_加到diff_中 diff_ = weight_decay*temp_ + diff_
        caffe_axpy(net_params[param_id]->count(),
            local_decay,
            temp_[param_id]->cpu_data(),
            net_params[param_id]->mutable_cpu_diff());
      } else {
        LOG(FATAL) << "Unknown regularization type: " << regularization_type;
      }
    }
    break;
  }
// GPU代码略
}

ComputeUpdatedValue

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
template <typename Dtype>
void SGDSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
  // 获取所有可以更新的参数的vector
  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
  // 获取所有参数对应的learning_rate的vector
  const vector<float>& net_params_lr = this->net_->params_lr();
  // 获取momentum数值
  Dtype momentum = this->param_.momentum();
  // 实际的learning_rate为全局的learning_rate乘以每个参数对应的learning_rate
  Dtype local_rate = rate * net_params_lr[param_id];
  switch (Caffe::mode()) {
  case Caffe::CPU: {
    // 关于SGD的公式参考caffe官网tutorial的Solver部分
    // history_存储了上一次的梯度,下面这个函数:
    // history_ = learning_rate*diff_ + momentum*history
    caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
              net_params[param_id]->cpu_diff(), momentum,
              history_[param_id]->mutable_cpu_data());
    // 把当前的梯度拷贝给参数Blob的diff_
    caffe_copy(net_params[param_id]->count(),
        history_[param_id]->cpu_data(),
        net_params[param_id]->mutable_cpu_diff());
    break;
  }
  case Caffe::GPU: {
    // GPU代码略
  }
}

至此Solver主要的代码都已经分析完了,总结起来主要有:(1)solver_factory的register和create不同类型Solver的机制,(2)通过signal_handler来获取系统信号,并根据用户或默认的设置进行相应的处理,(3)Solver::Solve函数的具体实现的分析,(4)SGDSolver::ApplyUpdate函数的具体实现。前面三个部分都属于基类的,最后一个是SGDSolver这个子类的,如果用户想要实现自己的Solver类,也应该类似地去继承基类,并实现自己的ApplyUpdate函数,在代码的末尾通过register宏完成注册,便可以被成功的调用。

在上文对Google Protocol Buffer进行了简单的介绍之后,本文将对caffe的Command Line Interfaces进行分析。

本文将从一个比较宏观的层面上去了解caffe怎么去完成一些初始化的工作和使用Solver的接口函数,本文将主要分为四部分的内容:

  • Google Flags的使用
  • Register Brew Function的宏的定义和使用
  • train()函数的具体实现
  • SolverParameter的具体解析过程

Google Flags的使用

Caffe官网中可以看到,caffe的Command Line Interfaces一共提供了四个功能:train/test/time/device_query,而Interfaces的输入除了这四种功能还可以输入诸如-solver/-weights/-snapshot/-gpu等参数。这些参数的解析是通过Google Flags这个工具来完成的。

在caffe.cpp(位于/CAFFE_ROOT/tools/caffe.cpp)的开头,我们可以看到很多这样的宏:

DEFINE_string(gpu, "",
    "Optional; run in GPU mode on given device IDs separated by ','."
    "Use '-gpu all' to run on all available GPUs. The effective training "
    "batch size is multiplied by the number of devices.");

这个宏的使用方式为DEFINE_xxx(name, default_value, instruction);,这样就定义了一个xxx类型名为FLAGS_name的标志,如果用户没有在Command Line中提供其值,那么会默认为default_value,instruction是这个标志含义的说明。因此,上面的代码定义了一个string类型的名为FLAGS_gpu的标志,如果在Command Line中用户没有提供值,那么会默认为空字符串,根据说明可以得知这个标志是提供给用户来指定caffe将使用的GPU的。其余的定义也是类似的理解方式就不一一列举了。

解析这些标志的代码在caffe.cpp中的main()中调用了/CAFFE_ROOT/src/common.cpp中的GlobalInit(&argc, &argv)函数:

1
2
3
4
5
6
7
8
void GlobalInit(int* pargc, char*** pargv) {
  // Google flags.
  ::gflags::ParseCommandLineFlags(pargc, pargv, true);
  // Google logging.
  ::google::InitGoogleLogging(*(pargv)[0]);
  // Provide a backtrace on segfault.
  ::google::InstallFailureSignalHandler();
}

第三行的函数就是Google Flags用来解析输入的参数的,前两个参数分别是指向main()的argc和argv的指针,第三个参数为true,表示在解析完所有的标志之后将这些标志从argv中清除,因此在解析完成之后,argc的值为2,argv[0]为main,argv[1]为train/test/time/device_query中的一个。

Register Brew Function的宏的定义和使用

Caffe在Command Line Interfaces中一共提供了4种功能:train/test/time/device_query,分别对应着四个函数,这四个函数的调用是通过一个叫做g_brew_map的全局变量来完成的:

1
2
3
4
// A simple registry for caffe commands.
typedef int (*BrewFunction)();
typedef std::map<caffe::string, BrewFunction> BrewMap;
BrewMap g_brew_map;

g_brew_map是一个key为string类型,value为BrewFunction类型的一个map类型的全局变量,BrewFunction是一个函数指针类型,指向的是参数为空,返回值为int的函数,也就是train/test/time/device_query这四个函数的类型。在train等四个函数实现的后面都紧跟着这样一句宏的调用:RegisterBrewFunction(train);

其中使用的宏的具体定义为:

1
2
3
4
5
6
7
8
9
10
\#define RegisterBrewFunction(func) \
namespace { \
class __Registerer_##func { \
 public: /* NOLINT */ \
  __Registerer_##func() { \
    g_brew_map[#func] = &func; \
  } \
}; \
__Registerer_##func g_registerer_##func; \
}

以train函数为例子,RegisterBrewFunction(train)这个宏的作用是定义了一个名为__Register_train的类,在定义完这个类之后,定义了一个这个类的变量,会调用构造函数,这个类的构造函数在前面提到的g_brew_map中添加了key为”train”,value为指向train函数的指针的一个元素。

然后函数的调用在main()函数中是通过下面的这段代码实现的,在完成初始化(GlobalInit)之后,有这样一句代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
// main()中的调用代码
return GetBrewFunction(caffe::string(argv[1]))();
// BrewFunction的具体实现
static BrewFunction GetBrewFunction(const caffe::string& name) {
  if (g_brew_map.count(name)) {
    return g_brew_map[name];
  } else {
    LOG(ERROR) << "Available caffe actions:";
    for (BrewMap::iterator it = g_brew_map.begin();
         it != g_brew_map.end(); ++it) {
      LOG(ERROR) << "\t" << it->first;
    }
    LOG(FATAL) << "Unknown action: " << name;
    return NULL;  // not reachable, just to suppress old compiler warnings.
  }
}

还是以train函数为例子,如果我们在Command Line中输入了caffe train ,经过Google Flags的解析argv[1]=train,因此,在GetBrewFunction中会通过g_brew_map返回一个指向train函数的函数指针,最后在main函数中就通过这个返回的函数指针完成了对train函数的调用。

总结一下:RegisterBrewFunction这个宏在每一个实现主要功能的函数之后将这个函数的名字和其对应的函数指针添加到了g_brew_map中,然后在main函数中,通过GetBrewFunction得到了我们需要调用的那个函数的函数指针,并完成了调用。

train()函数的具体实现

接下来我们仔细地分析一下在train()的具体实现。

首先是这样的一段代码:

1
2
3
4
CHECK_GT(FLAGS_solver.size(), 0) << "Need a solver definition to train.";
CHECK(!FLAGS_snapshot.size() || !FLAGS_weights.size())
    << "Give a snapshot to resume training or weights to finetune "
    "but not both.";

这段代码的第一行使用了glog的CHECK_GT宏(含义为check greater than),检查FLAGS_solver的size是否大于0,如果小于或等于0则输出提示:”Need a solver definition to train”。FLAGS_solver是最开始通过DEFINE_string定义的标志,如果我们希望训练一个模型,那么自然应该应该提供对应的solver定义文件的路径,这一句话正是在确保我们提供了这样的路径。这样的检查语句在后续的代码中会经常出现,将不再一一详细解释,如果有不清楚含义的glog宏可以去看看文档。 与第一行代码类似,第二行代码是确保用户没有同时提供snapshot和weights参数,这两个参数都是继续之前的训练或者进行fine-tuning的,如果同时指明了这两个标志,则不知道到底应该从哪个路径的文件去读入模型的相关参数更为合适。

然后出现了SolverParameter solver_param的声明和解析的代码:

1
2
caffe::SolverParameter solver_param;
caffe::ReadSolverParamsFromTextFileOrDie(FLAGS_solver, &solver_param);

SolverParameter是通过Google Protocol Buffer自动生成的一个类,如果有不清楚的可以参考上一篇文章。而具体的解析函数将在下一部分具体解释。

接下来这一部分的代码是根据用户的设置来选择caffe工作的模式(GPU或CPU)以及使用哪些GPU(caffe已经支持了多GPU同时工作!具体参考:官网tutorial的Parallelism部分):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
// If the gpus flag is not provided, allow the mode and device to be set
// in the solver prototxt.
if (FLAGS_gpu.size() == 0
    && solver_param.solver_mode() == caffe::SolverParameter_SolverMode_GPU) {
    if (solver_param.has_device_id()) {
        FLAGS_gpu = ""  +
            boost::lexical_cast<string>(solver_param.device_id());
    } else {  // Set default GPU if unspecified
        FLAGS_gpu = "" + boost::lexical_cast<string>(0);
    }
}
vector<int> gpus;
get_gpus(&gpus);
if (gpus.size() == 0) {
  LOG(INFO) << "Use CPU.";
  Caffe::set_mode(Caffe::CPU);
} else {
  ostringstream s;
  for (int i = 0; i < gpus.size(); ++i) {
    s << (i ? ", " : "") << gpus[i];
  }
  LOG(INFO) << "Using GPUs " << s.str();

  solver_param.set_device_id(gpus[0]);
  Caffe::SetDevice(gpus[0]);
  Caffe::set_mode(Caffe::GPU);
  Caffe::set_solver_count(gpus.size());
}

首先是判断用户在Command Line中是否输入了gpu相关的参数,如果没有(FLAGS_gpu.size()==0)但是用户在solver的prototxt定义中提供了相关的参数,那就把相关的参数放到FLAGS_gpu中,如果用户仅仅是选择了在solver的prototxt定义中选择了GPU模式,但是没有指明具体的gpu_id,那么就默认设置为0。

接下来的代码则通过一个get_gpus的函数,将存放在FLAGS_gpu中的string转成了一个vector,并完成了具体的设置。

下面的代码声明并通过SolverRegistry初始化了一个指向Solver类型的shared_ptr。并通过这个shared_ptr指明了在遇到系统信号(用户按了ctrl+c或者关闭了当前的terminal)时的处理方式。

1
2
3
4
5
6
7
8
caffe::SignalHandler signal_handler(
      GetRequestedAction(FLAGS_sigint_effect),
      GetRequestedAction(FLAGS_sighup_effect));

shared_ptr<caffe::Solver<float> >
    solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));

solver->SetActionFunction(signal_handler.GetActionFunction());

接下来判断了一下用户是否定义了snapshot或者weights这两个参数中的一个,如果定义了则需要通过Solver提供的接口从snapshot或者weights文件中去读取已经训练好的网络的参数:

1
2
3
4
5
6
if (FLAGS_snapshot.size()) {
  LOG(INFO) << "Resuming from " << FLAGS_snapshot;
  solver->Restore(FLAGS_snapshot.c_str());
} else if (FLAGS_weights.size()) {
  CopyLayers(solver.get(), FLAGS_weights);
}

最后,如果用户设置了要使用多个gpu,那么要声明一个P2PSync类型的对象,并通过这个对象来完成多gpu的计算,这一部分的代码,这一系列的文章会暂时先不涉及。而如果是只使用单个gpu,那么就通过Solver的Solve()开始具体的优化过程。在优化结束之后,函数将0值返回给main函数,整个train过程到这里也就结束了:

1
2
3
4
5
6
7
8
9
if (gpus.size() > 1) {
  caffe::P2PSync<float> sync(solver, NULL, solver->param());
  sync.run(gpus);
} else {
  LOG(INFO) << "Starting Optimization";
  solver->Solve();
}
LOG(INFO) << "Optimization Done.";
return 0;

上面的代码中涉及了很多Solver这个类的接口,这些内容都将在下一篇文章中进行具体的分析。

SolverParameter的具体解析过程

前面提到了SolverParameter是通过ReadSolverParamsFromTextFileOrDie来完成解析的,这个函数的实现在/CAFFE_ROOT/src/caffe/util/upgrade_proto.cpp里,我们来看一下具体的过程:

1
2
3
4
5
6
7
// Read parameters from a file into a SolverParameter proto message.
void ReadSolverParamsFromTextFileOrDie(const string& param_file,
                                       SolverParameter* param) {
  CHECK(ReadProtoFromTextFile(param_file, param))
      << "Failed to parse SolverParameter file: " << param_file;
  UpgradeSolverAsNeeded(param_file, param);
}

这里调用了先后调用了两个函数,首先是ReadProtoFromTextFile,这个函数的作用是从param_file这个路径去读取solver的定义,并将文件中的内容解析存到param这个指针指向的对象,具体的实现在/CAFFE_ROOT/src/caffe/util/io.cpp的开始:

1
2
3
4
5
6
7
8
9
bool ReadProtoFromTextFile(const char* filename, Message* proto) {
  int fd = open(filename, O_RDONLY);
  CHECK_NE(fd, -1) << "File not found: " << filename;
  FileInputStream* input = new FileInputStream(fd);
  bool success = google::protobuf::TextFormat::Parse(input, proto);
  delete input;
  close(fd);
  return success;
}

这段代码首先是打开了文件,并且读取到了一个FileInputStream的指针中,然后通过protobuf的TextFormat::Parse函数完成了解析。

然后UpgradeSolverAsNeeded完成了新老版本caffe.proto的兼容处理:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
// Check for deprecations and upgrade the SolverParameter as needed.
bool UpgradeSolverAsNeeded(const string& param_file, SolverParameter* param) {
  bool success = true;
  // Try to upgrade old style solver_type enum fields into new string type
  if (SolverNeedsTypeUpgrade(*param)) {
    LOG(INFO) << "Attempting to upgrade input file specified using deprecated "
              << "'solver_type' field (enum)': " << param_file;
    if (!UpgradeSolverType(param)) {
      success = false;
      LOG(ERROR) << "Warning: had one or more problems upgrading "
                 << "SolverType (see above).";
    } else {
      LOG(INFO) << "Successfully upgraded file specified using deprecated "
                << "'solver_type' field (enum) to 'type' field (string).";
      LOG(WARNING) << "Note that future Caffe releases will only support "
                   << "'type' field (string) for a solver's type.";
    }
  }
  return success;
}

主要的问题就是在旧版本中Solver的type是enum类型,而新版本的变为了string。

总结

本文从主要分析了caffe.cpp中实现各种具体功能的函数的调用的机制,以及在Command Line中用户输入的各种参数是怎么解析的,以及最常用的train函数的具体代码。通过这些分析,我们对Solver类型的接口有了一个初步的认识和了解,在下一篇文章中,我们将去具体地分析Solver的实现。

在Caffe中定义一个网络是通过编辑一个prototxt文件来完成的,一个简单的网络定义文件如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
name: "ExampleNet"
layer {
  name: "data"
  type: "Data"
  top: "data"
  top: "label"
  data_param {
    source: "path/to/train_database"
    batch_size: 64
    backend: LMDB
  }
}
layer {
  name: "conv1"
  type: "Convolution"
  bottom: "data"
  top: "conv1"
  convolution_param {
    num_output: 20
    kernel_size: 5
    stride: 1
  }
}
layer {
  name: "ip1"
  type: "InnerProduct"
  bottom: "conv1"
  top: "ip1"
  inner_product_param {
    num_output: 500
  }
}
layer {
  name: "loss"
  type: "SoftmaxWithLoss"
  bottom: "ip1"
  bottom: "label"
  top: "loss"
}

这个网络定义了一个name为ExampleNet的网络,这个网络的输入数据是LMDB数据,batch_size为64,包含了一个卷积层和一个全连接层,训练的loss function为SoftmaxWithLoss。通过这种简单的key: value描述方式,用户可以很方便的定义自己的网络,利用Caffe来训练和测试网络,验证自己的想法。

Caffe中定义了丰富的layer类型,每个类型都有对应的一些参数来描述这一个layer。为了说明的方便,接下来将通过一个简单的例子来展示Caffe是如何使用Google Protocol Buffer来完成Solver和Net的定义。

首先我们需要了解Google Protocol Buffer定义data schema的方式,Google Protocol Buffer通过一种类似于C++的语言来定义数据结构,下面是官网上一个典型的AddressBook例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
// AddressBook.proto
package tutorial;

message Person {
  required string name = 1;
  required int32 id = 2;
  optional string email = 3;

  enum PhoneType {
    MOBILE = 0;
    HOME = 1;
    WORK = 2;
  }

  message PhoneNumber {
    required string number = 1;
    optional PhoneType type = 2 [default = HOME];
  }

  repeated PhoneNumber phone = 4;
}

message AddressBook {
  repeated Person person = 1;
}

第2行的package tutorial类似于C++里的namespace,message可以简单的理解为一个class,message可以嵌套定义。每一个field除了一般的int32和string等类型外,还有一个属性来表明这个field是required,optional或者’repeated’。required的field必须存在,相对应的optional的就可以不存在,repeated的field可以出现0次或者多次。这一点对于Google Protocol Buffer的兼容性很重要,比如新版本的AddressBook添加了一个string类型的field,只有把这个field的属性设置为optional,就可以保证新版本的代码读取旧版本的数据也不会出错,新版本只会认为旧版本的数据没有提供这个optional field,会直接使用default。同时我们也可以定义enum类型的数据。每个field等号右侧的数字可以理解为在实际的binary encoding中这个field对应的key值,通常的做法是将经常使用的field定义为0-15的数字,可以节约存储空间(涉及到具体的encoding细节,感兴趣的同学可以看看官网的解释),其余的field使用较大的数值。

类似地在caffe/src/caffe/proto/中有一个caffe.proto文件,其中对layer的部分定义为:

1
2
3
4
5
6
7
message LayerParameter {
  optional string name = 1; // the layer name
  optional string type = 2; // the layer type
  repeated string bottom = 3; // the name of each bottom blob
  repeated string top = 4; // the name of each top blob
//  other fields
}

在定义好了data schema之后,需要使用protoc compiler来编译定义好的proto文件。常用的命令为:

protoc -I=/protofile/directory –cpp_out=/output/directory /path/to/protofile

-I之后为proto文件的路径,–cpp_out为编译生成的.h和.cc文件的路径,最后是proto文件的路径。编译之后会生成AddressBook.pb.h和AddressBook/pb.cc文件,其中包含了大量的接口函数,用户可以利用这些接口函数获取和改变某个field的值。对应上面的data schema定义,有这样的一些接口函数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
// name
inline bool has_name() const;
inline void clear_name();
inline const ::std::string& name() const;  //getter
inline void set_name(const ::std::string& value);  //setter
inline void set_name(const char* value);  //setter
inline ::std::string* mutable_name();

// email
inline bool has_email() const;
inline void clear_email();
inline const ::std::string& email() const; //getter
inline void set_email(const ::std::string& value);  //setter
inline void set_email(const char* value);  //setter
inline ::std::string* mutable_email();

// phone
inline int phone_size() const;
inline void clear_phone();
inline const ::google::protobuf::RepeatedPtrField< ::tutorial::Person_PhoneNumber >& phone() const;
inline ::google::protobuf::RepeatedPtrField< ::tutorial::Person_PhoneNumber >* mutable_phone();
inline const ::tutorial::Person_PhoneNumber& phone(int index) const;
inline ::tutorial::Person_PhoneNumber* mutable_phone(int index);
inline ::tutorial::Person_PhoneNumber* add_phone();

每个类都有对应的setter和getter,因为phone是repeated类型的,所以还多了通过index来获取和改变某一个元素的setter和getter,phone还有一个获取数量的phone_size函数。

官网上的tutorial是通过bool ParseFromIstream(istream* input);来从binary的数据文件里解析数据,为了更好地说明Caffe中读取数据的方式,我稍微修改了代码,使用了和Caffe一样的方式通过TextFormat::Parse来解析文本格式的数据。具体的代码如下:

#include <iostream>
#include <fstream>
#include <string>
#include <algorithm>
#include <stdint.h>
#include <fcntl.h>
#include <unistd.h>
#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <google/protobuf/text_format.h>
#include "addressBook.pb.h"

using namespace std;
using google::protobuf::io::FileInputStream;
using google::protobuf::io::FileOutputStream;
using google::protobuf::io::ZeroCopyInputStream;
using google::protobuf::io::CodedInputStream;
using google::protobuf::io::ZeroCopyOutputStream;
using google::protobuf::io::CodedOutputStream;
using google::protobuf::Message;

// Iterates though all people in the AddressBook and prints info about them.
void ListPeople(const tutorial::AddressBook& address_book) {
  for (int i = 0; i < address_book.person_size(); i++) {
    const tutorial::Person& person = address_book.person(i);

    cout << "Person ID: " << person.id() << endl;
    cout << "  Name: " << person.name() << endl;
    if (person.has_email()) {
      cout << "  E-mail address: " << person.email() << endl;
    }

    for (int j = 0; j < person.phone_size(); j++) {
      const tutorial::Person::PhoneNumber& phone_number = person.phone(j);

      switch (phone_number.type()) {
        case tutorial::Person::MOBILE:
          cout << "  Mobile phone #: ";
          break;
        case tutorial::Person::HOME:
          cout << "  Home phone #: ";
          break;
        case tutorial::Person::WORK:
          cout << "  Work phone #: ";
          break;
      }
      cout << phone_number.number() << endl;
    }
  }
}

// Main function:  Reads the entire address book from a file and prints all
//   the information inside.
int main(int argc, char* argv[]) {
  // Verify that the version of the library that we linked against is
  // compatible with the version of the headers we compiled against.
  GOOGLE_PROTOBUF_VERIFY_VERSION;

  if (argc != 2) {
    cerr << "Usage:  " << argv[0] << " ADDRESS_BOOK_FILE" << endl;
    return -1;
  }

  tutorial::AddressBook address_book;

  {
    // Read the existing address book.
    int fd = open(argv[1], O_RDONLY);
    FileInputStream* input = new FileInputStream(fd);
    if (!google::protobuf::TextFormat::Parse(input, &address_book)) {
      cerr << "Failed to parse address book." << endl;
      delete input;
      close(fd);
      return -1;
    }
  }

  ListPeople(address_book);

  // Optional:  Delete all global objects allocated by libprotobuf.
  google::protobuf::ShutdownProtobufLibrary();

  return 0;
}

读取和解析数据的代码:

1
2
3
4
5
int fd = open(argv[1], O_RDONLY);
FileInputStream* input = new FileInputStream(fd);
if (!google::protobuf::TextFormat::Parse(input, &address_book)) {
  cerr << "Failed to parse address book." << endl;
}

这一段代码将input解析为我们设计的数据格式,写入到address_book中。之后再调用ListPeople函数输出数据,来验证数据确实是按照我们设计的格式来存储和读取的。ListPeople函数中使用了之前提到的各个getter接口函数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# ExampleAddressBook.prototxt
person {
  name: "Alex K"
  id: 1
  email: "kongming.liang@abc.com"
  phone {
    number: "+86xxxxxxxxxxx"
    type: MOBILE
  }
}

person {
  name: "Andrew D"
  id: 2
  email: "xuesong.deng@vipl.ict.ac.cn"
  phone {
    number: "+86xxxxxxxxxxx"
    type: MOBILE
  }
  phone {
    number: "+86xxxxxxxxxxx"
    type: WORK
  }
}

上面的文件的解析结果如图所示:

Caffe是一个基于C++和cuda开发的深度学习框架。其使用和开发的便捷特性使其成为近年来机器学习和计算机视觉领域最广为使用的框架。

笔者使用Caffe做各种实验也有一段时间了,除了Caffe支持的各种计算方式(卷积/pooling/全连接等)之外,在自己的使用中开始遇到一些需要自己定义的网络,也有一些新的模型,比如2015ImageNet的冠军里面的shortcut结构,都需要实现新的layer来完成计算。为了去改造Caffe,首先学习Caffe的源码,接下来一系列的博客将记录和分享这个学习的过程。作为一个C++的菜鸟,如有错误,希望读者指出。

首先,简单介绍一下Caffe的代码结构。Caffe主要包含了4个大类:

  • Solver: An interface for classes that perform optimization on Nets
  • Net: Connects Layers together into a directed acyclic graph (DAG) specified by a NetParameter
  • Layer: An interface for the units of computation which can be composed into a Net
  • Blob: A wrapper around SyncedMemory holders serving as the basic computational unit through which Layers, Nets, and Solvers interact

其中Solver这个类实现了优化函数的封装,其中有一个protected的成员:shared_ptr<Net > net_;,这个成员是一个指向Net类型的智能指针(shared_ptr),Solver正是通过这个指针来和网络Net来交互并完成模型的优化。不同的子类分别实现了不同的优化方法:SGDSolver, NesterovSolver, AdaGradSolver, RMSPropSolver, AdaDeltaSolver和AdamSolver。具体每个Solver对应的优化方法参考:Caffe Solver Methods。 类似地Layer这个类派生出了很多子类,这些子类实现了Data的读取和Convolution, Pooling, InnerProduct等各种功能的layer。 Net则是对整个网络的一个封装,其中有一个成员为:vector<shared_ptr<Layer > > layers_;,这个vector中包含了整个网络中每一层layer的智能指针,Net通过调用这些layer各自的forward()和backward()接口实现了网络整体的ForwardBackward()。 Blob则是Caffe对数据的封装,在整个网络的计算中,不管是数据还是网络的参数和梯度都是这个类的对象,均为num\*channel\*width\*height形式的数据。

目前,初步的打算是从外部接口逐渐深入,首先学习caffe的主函数的接口,然后是Solver特别是默认使用的SGDSolver的具体实现,调用了哪些Net的接口等;接下来学习和了解Net是如何封装各个Layer来组成一个整体的网络,还有就是Net中如何利用Layer的接口完成数据的forward和backward的传导;最后具体了解不同的Layer如何实现自定义的forward()和backward()接口,完成最重要的计算。虽然目前Caffe已经实现了多GPU并行化的功能,但是在这个学习的过程中,我将暂时忽略这一部分的代码,而集中注意力到前面所述的这几部分内容上。

除了清晰的代码结构,让Caffe变得易用更应该归功于Google Protocol Buffer的使用。Google Protocol Buffer是Google开发的一个用于serializing结构化数据的开源工具:

Protocol buffers are a language-neutral, platform-neutral extensible mechanism for serializing structured data.

Caffe使用这个工具来定义Solver和Net,以及Net中每一个layer的参数。这使得只是想使用Caffe目前支持的Layer(已经非常丰富了)来做一些实验或者demo的用户可以不去和代码打交道,只需要在*.prototxt文件中描述自己的Solver和Net即可,再通过Caffe提供的command line interfaces就可以完成模型的train/finetune/test等功能。下一篇文章将通过一个简单的例子来展示Google Protocol Buffer的作用和便捷之处。

最近玩了一下今年Science上发表的一篇关于聚类的文章。记录一下。

算法的过程并不复杂,但确实十分聪明。算法的输入只需要一个distance的矩阵,第i行第j列的元素就是第i个数据和第j个数据的距离。

文章首先定义了一个数据的密度(density):

其中:

在作者的给出的代码里对密度的定义给出了一个更合理的公式:

简单的来说,这两个公式都反映了一个数据其一个小邻域内的其他数据的多少或者稠密程度,也就估计出来了这个数据的一个局部的密度。

值得一提的是dc这个值,对算法的影响还是比较大,文章里给出的建议是升序排列distance,然后取1%到2%的那个distance作为dc。

然后需要计算delta:

这个值越高意味着这个数据越有可能是聚类的中心,因为delta值是这个数据与比自己的密度更大的数据之间最小的距离,这个距离越大,说明这个数据远离了其他比自己密度大的数据,是自己的邻域里的局部最大,也就是题目中说的density peaks。

当然在保证这个数据有很大的delta值的情况下,也需要保证这个数据有很高的density,否则就意味着这个数据是一个噪声,因为它既远离了其他的高密度区域,自己也不处于一个高密度区域(很小的density值)。

在作者给出的matlab版本的代码里,在计算完了密度和delta之后,画出了横轴为密度,纵轴为delta的decision graph,然后可以认为选定一个点来决定最小的delta值和密度值,在dc选择合理的情况下,通常可以很轻松地作出这个决定。如下图:

Decision graph

接下来需要对除了聚类中心的数据进行指派类别,第i个数据的类别和比它有更高密度的子集里的最近邻的那个数据的类别一样,通过这种扩散的方法,类别的指派从各个聚类中心(也是各个类的density peak)向低密度的区域指派。通常的聚类方法比如kmeans,指派的时候,只是简单的将数据指派到距离最近的那个中心,但是如果数据的分布本身不是球形的,而是椭圆甚至某种奇怪的曲线的话,这样的指派显然是错误的。

但是如果按照本文的方法,就可以保证数据的真实分布被反映出来。还有一个优点就是类别的指派过程是单步的,不像kmeans等聚类方法需要迭代。

然后就是最后一个步骤,找到halo区域,并且定义为噪声。首先对每一个类定义一个border region,这个区域是属于这个类别而且有距离其他类别数据小于dc的数据,这些数据会出现在各个类别的交界区域。然后将这个区域内最大的密度定义为$\rho_{b}$。

在这个类当中,所有密度比$\rho_{b}$小的数据都被定义为噪声。这样假设的合理性在于,和其他类的交界区域已经是这个类比较低密度的区域,如果比这个密度还要低,说明这些数据已经远离了这个类的核心区域,理应被定为噪声。作者提供的测试数据得到的结果见下图:

Clustering result

总结一下,这个算法的核心思想在于定义了密度,我的体会是这个密度在数据量比较充足的情况下,和数据本身的真实的概率分布很接近,也就是说,用这个算法,不论去估计什么样形式的数据分布(高斯、多项式、卡方,甚至自定义的)都依靠数据估计得到比较准确的密度值,而不是去对数据的分布预先做出假设。但这也带来了一个隐患,在实际的模式识别问题中,我们面对的数据分布通常是高维的,很多时候也是不充足的,通过这个方法估计得到的密度值也就可能不够准确。

另外,数据充足与否是一个相对的概念,不是说有一万个数据就比一千个数据充足,如果一千个数据的那个分布比较简单,且数据都靠近分布的核心区域,比较稠密,那也是充足的;一万个数据,但是如果数据分布很复杂,变化很多,这一万个数据又比较分散,也可以认为这一万个数据是不充足的。

最近还在用这个算法做更多的实验,有新的结果之后会保持更新。

再另外,作者主页提供的matlab代码运行效率实在太慢,我自己实现了一个稍微快一些的版本,可以看这里:

点我