Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modular model definitions #1290

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
PROJECT := caffe

CONFIG_FILE := Makefile.config
CONFIG_FILE ?= Makefile.config
include $(CONFIG_FILE)

BUILD_DIR_LINK := $(BUILD_DIR)
Expand Down
4 changes: 3 additions & 1 deletion include/caffe/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,8 @@ class Caffe {
inline static void set_mode(Brew mode) { Get().mode_ = mode; }
// Sets the phase.
inline static void set_phase(Phase phase) { Get().phase_ = phase; }
// Sets the random seed of both boost and curand
// Random seed of both boost and curand
static unsigned int get_random_seed();
static void set_random_seed(const unsigned int seed);
// Sets the device. Since we have cublas and curand stuff, set device also
// requires us to reset those values.
Expand All @@ -161,6 +162,7 @@ class Caffe {
curandGenerator_t curand_generator_;
#endif
shared_ptr<RNG> random_generator_;
unsigned int random_generator_seed_;

Brew mode_;
Phase phase_;
Expand Down
92 changes: 75 additions & 17 deletions include/caffe/data_layers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
#include <utility>
#include <vector>

#include "boost/scoped_ptr.hpp"
#include "boost/weak_ptr.hpp"
#include "boost/random/mersenne_twister.hpp"
#include "boost/random/uniform_real.hpp"
#include "boost/random/variate_generator.hpp"
#include "hdf5.h"

#include "caffe/blob.hpp"
Expand All @@ -16,9 +19,15 @@
#include "caffe/internal_thread.hpp"
#include "caffe/layer.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/util/blocking_queue.hpp"

namespace caffe {

using boost::weak_ptr;
using boost::mt19937;
using boost::uniform_real;
using boost::variate_generator;

/**
* @brief Provides base for data layers that feed blobs to the Net.
*
Expand Down Expand Up @@ -52,12 +61,17 @@ class BaseDataLayer : public Layer<Dtype> {
bool output_labels_;
};

// One prefetched unit of work: holds the data blob and the label blob that a
// prefetch thread fills and the forward pass later consumes. label_ is only
// meaningful when the owning layer outputs labels — confirm against
// BasePrefetchingDataLayer's use of output_labels_.
template <typename Dtype>
class Batch {
public:
// Raw input data and (optional) ground-truth labels for one batch.
Blob<Dtype> data_, label_;
};

template <typename Dtype>
class BasePrefetchingDataLayer :
public BaseDataLayer<Dtype>, public InternalThread {
public:
explicit BasePrefetchingDataLayer(const LayerParameter& param)
: BaseDataLayer<Dtype>(param) {}
explicit BasePrefetchingDataLayer(const LayerParameter& param);
virtual ~BasePrefetchingDataLayer() {}
// LayerSetUp: implements common data layer setup functionality, and calls
// DataLayerSetUp to do special data layer setup for individual layer types.
Expand All @@ -70,22 +84,64 @@ class BasePrefetchingDataLayer :
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);

virtual void CreatePrefetchThread();
virtual void JoinPrefetchThread();
// The thread's function
virtual void InternalThreadEntry() {}
// Prefetches batches (asynchronously if to GPU memory)
static const int PREFETCH_COUNT = 3;

protected:
Blob<Dtype> prefetch_data_;
Blob<Dtype> prefetch_label_;
virtual void InternalThreadEntry();
virtual void load_batch(Batch<Dtype>* batch) = 0;

Batch<Dtype> prefetch_[PREFETCH_COUNT];
blocking_queue<Batch<Dtype>*> prefetch_free_;
blocking_queue<Batch<Dtype>*> prefetch_full_;
int device_;

Blob<Dtype> transformed_data_;
};

// Prefetches datums to host memory that can be read by multiple data layers.
// Loaders reading the same source share a single background Body thread via
// the static instances_ registry, so the underlying dataset is only opened
// and iterated once per source.
class DataLoader {
public:
// NOTE(review): index presumably selects one of several sources configured
// in the DataParameter — confirm against the constructor definition.
DataLoader(const DataParameter& param, int index);
~DataLoader();

// Queue of recycled Datum buffers for the loader thread to refill.
inline blocking_queue<Datum*>& free() {
return body_.get()->free_;
}
// Queue of filled Datums ready for consumption by data layers.
inline blocking_queue<Datum*>& full() {
return body_.get()->full_;
}

protected:
// Background thread that iterates the dataset and moves Datums from
// free_ to full_.
class Body: public InternalThread {
public:
Body(const DataParameter& param, int index);
~Body();

// Thread main loop (runs until interrupted/stopped).
void InternalThreadEntry();

// Backing dataset and the current read position within it.
shared_ptr<Dataset<string, Datum> > dataset_;
Dataset<string, Datum>::const_iterator iter_;

// Producer/consumer queues shared with every DataLoader attached to
// this Body.
blocking_queue<Datum*> free_;
blocking_queue<Datum*> full_;

DISABLE_COPY_AND_ASSIGN(Body);
};

// Registry of live Bodies, keyed by source string; weak_ptr so a Body is
// destroyed once the last DataLoader referencing it goes away.
static map<string, weak_ptr<Body> > instances_;
// Guards instances_ against concurrent construction/destruction.
static boost::mutex instances_mutex_;

const string source_;
shared_ptr<Body> body_;

DISABLE_COPY_AND_ASSIGN(DataLoader);
};

template <typename Dtype>
class DataLayer : public BasePrefetchingDataLayer<Dtype> {
class DataLayer: public BasePrefetchingDataLayer<Dtype> {
public:
explicit DataLayer(const LayerParameter& param)
: BasePrefetchingDataLayer<Dtype>(param) {}
explicit DataLayer(const LayerParameter& param);
virtual ~DataLayer();
virtual void DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
Expand All @@ -98,10 +154,12 @@ class DataLayer : public BasePrefetchingDataLayer<Dtype> {
virtual inline int MaxTopBlobs() const { return 2; }

protected:
virtual void InternalThreadEntry();
virtual void load_batch(Batch<Dtype>* batch);
DataLoader* next_loader();

shared_ptr<Dataset<string, Datum> > dataset_;
Dataset<string, Datum>::const_iterator iter_;
vector<shared_ptr<DataLoader> > loaders_;
mt19937 rand_engine_;
uniform_real<float> rand_;
};

/**
Expand Down Expand Up @@ -244,7 +302,7 @@ class ImageDataLayer : public BasePrefetchingDataLayer<Dtype> {
protected:
shared_ptr<Caffe::RNG> prefetch_rng_;
virtual void ShuffleImages();
virtual void InternalThreadEntry();
virtual void load_batch(Batch<Dtype>* batch);

vector<std::pair<std::string, int> > lines_;
int lines_id_;
Expand Down Expand Up @@ -317,7 +375,7 @@ class WindowDataLayer : public BasePrefetchingDataLayer<Dtype> {

protected:
virtual unsigned int PrefetchRand();
virtual void InternalThreadEntry();
virtual void load_batch(Batch<Dtype>* batch);

shared_ptr<Caffe::RNG> prefetch_rng_;
vector<std::pair<std::string, vector<int> > > image_database_;
Expand Down
10 changes: 8 additions & 2 deletions include/caffe/internal_thread.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class Thread {
Thread(Callable func, A1 a1);
void join();
bool joinable();
void interrupt();
private:
void* thread_;
};
Expand All @@ -26,23 +27,28 @@ class Thread {
*/
class InternalThread {
public:
InternalThread() : thread_(NULL) {}
InternalThread() : thread_(NULL), must_stop_() {}
virtual ~InternalThread();

/** Returns true if the thread was successfully started. **/
bool StartInternalThread();

/** Will not return until the internal thread has exited. */
bool WaitForInternalThreadToExit();
bool StopInternalThread();

bool is_started() const { return thread_ != NULL && thread_->joinable(); }

bool must_stop() {
return must_stop_;
}

protected:
/* Implement this method in your subclass
with the code you want your thread to run. */
virtual void InternalThreadEntry() {}

caffe::Thread* thread_;
bool must_stop_;
};

} // namespace caffe
Expand Down
4 changes: 4 additions & 0 deletions include/caffe/syncedmem.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ class SyncedMemory {
SyncedHead head() { return head_; }
size_t size() { return size_; }

#ifndef CPU_ONLY
void async_gpu_push(const cudaStream_t& stream);
#endif

private:
void to_cpu();
void to_gpu();
Expand Down
84 changes: 84 additions & 0 deletions include/caffe/util/blocking_queue.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
#ifndef CAFFE_UTIL_BLOCKING_QUEUE_H_
#define CAFFE_UTIL_BLOCKING_QUEUE_H_

#include <queue>
#include "boost/thread.hpp"

namespace caffe {

// Thread-safe FIFO queue: push() never blocks; pop()/peek() block until an
// element is available. Synchronization uses a single mutex plus a condition
// variable; condition waits are boost interruption points, so a blocked
// consumer can be cancelled through thread interruption.
template<typename T>
class blocking_queue {
 public:
  blocking_queue()
      : last_wait_log_(time(0)),
        pops_() {
  }

  // Enqueue an element and wake one waiting consumer.
  void push(const T& t) {
    boost::mutex::scoped_lock lock(mutex_);
    queue_.push(t);
    // Unlock before notifying so the woken thread does not immediately
    // block again on the held mutex.
    lock.unlock();
    condition_.notify_one();
  }

  // True if the queue currently holds no elements. The answer may be stale
  // as soon as it is returned, since other threads push/pop concurrently.
  bool empty() const {
    boost::mutex::scoped_lock lock(mutex_);
    return queue_.empty();
  }

  // Non-blocking pop: returns false immediately if the queue is empty.
  bool try_pop(T& t) {
    boost::mutex::scoped_lock lock(mutex_);

    if (queue_.empty())
      return false;

    t = queue_.front();
    queue_.pop();
    return true;
  }

  // Blocking pop: waits until an element is available. If log_on_wait is
  // non-empty, it is logged at most once every 5 seconds while waiting,
  // to flag a starved consumer.
  T pop(const string& log_on_wait = "") {
    boost::mutex::scoped_lock lock(mutex_);

    while (queue_.empty()) {
      if (!log_on_wait.empty()) {
        time_t now = time(0);
        if (now - last_wait_log_ > 5) {
          last_wait_log_ = now;
          LOG(INFO) << log_on_wait;
        }
      }
      condition_.wait(lock);
    }

    T t = queue_.front();
    queue_.pop();
    pops_++;
    return t;
  }

  // Blocking peek: waits until an element is available and returns a copy
  // without removing it from the queue.
  T peek() {
    boost::mutex::scoped_lock lock(mutex_);

    while (queue_.empty())
      condition_.wait(lock);

    return queue_.front();
  }

  // Number of successful blocking pop() calls so far. Must be read under
  // the lock: pops_ is written by pop() while holding mutex_, so an
  // unsynchronized read here would be a data race.
  uint64_t pops() {
    boost::mutex::scoped_lock lock(mutex_);
    return pops_;
  }

 private:
  std::queue<T> queue_;
  mutable boost::mutex mutex_;
  boost::condition_variable condition_;
  time_t last_wait_log_;
  uint64_t pops_;
};

} // namespace caffe

#endif
4 changes: 4 additions & 0 deletions include/caffe/util/thread.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ bool Thread::joinable() {
return static_cast<boost::thread*>(this->thread_)->joinable();
}

void Thread::interrupt() {
static_cast<boost::thread*>(this->thread_)->interrupt();
}

} // namespace caffe

#endif
5 changes: 5 additions & 0 deletions src/caffe/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ Caffe::~Caffe() {
}
}

unsigned int Caffe::get_random_seed() {
return Get().random_generator_seed_;
}

void Caffe::set_random_seed(const unsigned int seed) {
// Curand seed
static bool g_curand_availability_logged = false;
Expand All @@ -124,6 +128,7 @@ void Caffe::set_random_seed(const unsigned int seed) {
}
// RNG seed
Get().random_generator_.reset(new RNG(seed));
Get().random_generator_seed_ = seed;
}

void Caffe::SetDevice(const int device_id) {
Expand Down
9 changes: 6 additions & 3 deletions src/caffe/internal_thread.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,17 @@
namespace caffe {

// Stops (interrupts and joins) the worker thread before releasing its
// handle, so the thread never outlives the object it runs on.
InternalThread::~InternalThread() {
  StopInternalThread();
  // delete on a null pointer is a well-defined no-op, so no guard is needed.
  delete thread_;
}

bool InternalThread::StartInternalThread() {
if (!WaitForInternalThreadToExit()) {
if (!StopInternalThread()) {
return false;
}
must_stop_ = false;
try {
thread_ = new caffe::Thread
(&InternalThread::InternalThreadEntry, this);
Expand All @@ -25,8 +26,10 @@ bool InternalThread::StartInternalThread() {
}

/** Will not return until the internal thread has exited. */
bool InternalThread::WaitForInternalThreadToExit() {
bool InternalThread::StopInternalThread() {
must_stop_ = true;
if (is_started()) {
thread_->interrupt();
try {
thread_->join();
} catch (...) {
Expand Down
Loading