diff --git a/docs/tutorial/interfaces.md b/docs/tutorial/interfaces.md index 40602948cc3..9006179d0f1 100644 --- a/docs/tutorial/interfaces.md +++ b/docs/tutorial/interfaces.md @@ -50,6 +50,13 @@ For a full example of fine-tuning, see examples/finetuning_on_flickr_style, but # query the first device caffe device_query -gpu 0 +**Parallelism**: the `-gpu` flag to the `caffe` tool can take a comma separated list of IDs to run on multiple GPUs. A solver and net will be instantiated for each GPU so the batch size is effectively multiplied by the number of GPUs. To reproduce single GPU training, reduce the batch size in the network definition accordingly. + + # train on GPUs 0 & 1 (doubling the batch size) + caffe train -solver examples/mnist/lenet_solver.prototxt -gpu 0,1 + # train on all GPUs (multiplying batch size by number of devices) + caffe train -solver examples/mnist/lenet_solver.prototxt -gpu all + ## Python The Python interface -- pycaffe -- is the `caffe` module and its scripts in caffe/python. `import caffe` to load models, do forward and backward, handle IO, visualize networks, and even instrument model solving. All model data, derivatives, and parameters are exposed for reading and writing. diff --git a/include/caffe/caffe.hpp b/include/caffe/caffe.hpp index 3c829f2f9b0..68a5e1d1d1a 100644 --- a/include/caffe/caffe.hpp +++ b/include/caffe/caffe.hpp @@ -10,6 +10,7 @@ #include "caffe/layer.hpp" #include "caffe/layer_factory.hpp" #include "caffe/net.hpp" +#include "caffe/parallel.hpp" #include "caffe/proto/caffe.pb.h" #include "caffe/solver.hpp" #include "caffe/util/benchmark.hpp" diff --git a/include/caffe/common.hpp b/include/caffe/common.hpp index 5f86bc2625b..1df6b9a14fb 100644 --- a/include/caffe/common.hpp +++ b/include/caffe/common.hpp @@ -98,12 +98,12 @@ void GlobalInit(int* pargc, char*** pargv); class Caffe { public: ~Caffe(); - inline static Caffe& Get() { - if (!singleton_.get()) { - singleton_.reset(new Caffe()); - } - return *singleton_; - } + + // Thread local context for Caffe. Moved to common.cpp instead of + // including boost/thread.hpp to avoid a boost/NVCC issues (#1009, #1010) + // on OSX. Also fails on Linux with CUDA 7.0.18. + static Caffe& Get(); + enum Brew { CPU, GPU }; // This random number generator facade hides boost and CUDA rng @@ -149,6 +149,11 @@ class Caffe { static void SetDevice(const int device_id); // Prints the current GPU status. static void DeviceQuery(); + // Parallel training info + inline static int solver_count() { return Get().solver_count_; } + inline static void set_solver_count(int val) { Get().solver_count_ = val; } + inline static bool root_solver() { return Get().root_solver_; } + inline static void set_root_solver(bool val) { Get().root_solver_ = val; } protected: #ifndef CPU_ONLY @@ -158,7 +163,8 @@ class Caffe { shared_ptr random_generator_; Brew mode_; - static shared_ptr singleton_; + int solver_count_; + bool root_solver_; private: // The private constructor to avoid duplicate instantiation. diff --git a/include/caffe/data_layers.hpp b/include/caffe/data_layers.hpp index 3958cb7ecb0..552d814131e 100644 --- a/include/caffe/data_layers.hpp +++ b/include/caffe/data_layers.hpp @@ -5,16 +5,17 @@ #include #include -#include "boost/scoped_ptr.hpp" #include "hdf5.h" #include "caffe/blob.hpp" #include "caffe/common.hpp" +#include "caffe/data_reader.hpp" #include "caffe/data_transformer.hpp" #include "caffe/filler.hpp" #include "caffe/internal_thread.hpp" #include "caffe/layer.hpp" #include "caffe/proto/caffe.pb.h" +#include "caffe/util/blocking_queue.hpp" #include "caffe/util/db.hpp" namespace caffe { @@ -33,6 +34,8 @@ class BaseDataLayer : public Layer { // This method may not be overridden except by the BasePrefetchingDataLayer. virtual void LayerSetUp(const vector*>& bottom, const vector*>& top); + // Data layers should be shared by multiple solvers in parallel + virtual inline bool ShareInParallel() const { return true; } virtual void DataLayerSetUp(const vector*>& bottom, const vector*>& top) {} // Data layers have no bottoms, so reshaping is trivial. @@ -50,12 +53,17 @@ class BaseDataLayer : public Layer { bool output_labels_; }; +template +class Batch { + public: + Blob data_, label_; +}; + template class BasePrefetchingDataLayer : public BaseDataLayer, public InternalThread { public: - explicit BasePrefetchingDataLayer(const LayerParameter& param) - : BaseDataLayer(param) {} + explicit BasePrefetchingDataLayer(const LayerParameter& param); // LayerSetUp: implements common data layer setup functionality, and calls // DataLayerSetUp to do special data layer setup for individual layer types. // This method may not be overridden. @@ -67,36 +75,38 @@ class BasePrefetchingDataLayer : virtual void Forward_gpu(const vector*>& bottom, const vector*>& top); - virtual void CreatePrefetchThread(); - virtual void JoinPrefetchThread(); - // The thread's function - virtual void InternalThreadEntry() {} + // Prefetches batches (asynchronously if to GPU memory) + static const int PREFETCH_COUNT = 3; protected: - Blob prefetch_data_; - Blob prefetch_label_; + virtual void InternalThreadEntry(); + virtual void load_batch(Batch* batch) = 0; + + Batch prefetch_[PREFETCH_COUNT]; + BlockingQueue*> prefetch_free_; + BlockingQueue*> prefetch_full_; + Blob transformed_data_; }; template class DataLayer : public BasePrefetchingDataLayer { public: - explicit DataLayer(const LayerParameter& param) - : BasePrefetchingDataLayer(param) {} + explicit DataLayer(const LayerParameter& param); virtual ~DataLayer(); virtual void DataLayerSetUp(const vector*>& bottom, const vector*>& top); - + // DataLayer uses DataReader instead for sharing for parallelism + virtual inline bool ShareInParallel() const { return false; } virtual inline const char* type() const { return "Data"; } virtual inline int ExactNumBottomBlobs() const { return 0; } virtual inline int MinTopBlobs() const { return 1; } virtual inline int MaxTopBlobs() const { return 2; } protected: - virtual void InternalThreadEntry(); + virtual void load_batch(Batch* batch); - shared_ptr db_; - shared_ptr cursor_; + DataReader reader_; }; /** @@ -111,6 +121,8 @@ class DummyDataLayer : public Layer { : Layer(param) {} virtual void LayerSetUp(const vector*>& bottom, const vector*>& top); + // Data layers should be shared by multiple solvers in parallel + virtual inline bool ShareInParallel() const { return true; } // Data layers have no bottoms, so reshaping is trivial. virtual void Reshape(const vector*>& bottom, const vector*>& top) {} @@ -144,6 +156,8 @@ class HDF5DataLayer : public Layer { virtual ~HDF5DataLayer(); virtual void LayerSetUp(const vector*>& bottom, const vector*>& top); + // Data layers should be shared by multiple solvers in parallel + virtual inline bool ShareInParallel() const { return true; } // Data layers have no bottoms, so reshaping is trivial. virtual void Reshape(const vector*>& bottom, const vector*>& top) {} @@ -185,6 +199,8 @@ class HDF5OutputLayer : public Layer { virtual ~HDF5OutputLayer(); virtual void LayerSetUp(const vector*>& bottom, const vector*>& top); + // Data layers should be shared by multiple solvers in parallel + virtual inline bool ShareInParallel() const { return true; } // Data layers have no bottoms, so reshaping is trivial. virtual void Reshape(const vector*>& bottom, const vector*>& top) {} @@ -235,7 +251,7 @@ class ImageDataLayer : public BasePrefetchingDataLayer { protected: shared_ptr prefetch_rng_; virtual void ShuffleImages(); - virtual void InternalThreadEntry(); + virtual void load_batch(Batch* batch); vector > lines_; int lines_id_; @@ -307,7 +323,7 @@ class WindowDataLayer : public BasePrefetchingDataLayer { protected: virtual unsigned int PrefetchRand(); - virtual void InternalThreadEntry(); + virtual void load_batch(Batch* batch); shared_ptr prefetch_rng_; vector > > image_database_; diff --git a/include/caffe/data_reader.hpp b/include/caffe/data_reader.hpp new file mode 100644 index 00000000000..8ed5542cb8d --- /dev/null +++ b/include/caffe/data_reader.hpp @@ -0,0 +1,82 @@ +#ifndef CAFFE_DATA_READER_HPP_ +#define CAFFE_DATA_READER_HPP_ + +#include +#include +#include + +#include "caffe/common.hpp" +#include "caffe/internal_thread.hpp" +#include "caffe/util/blocking_queue.hpp" +#include "caffe/util/db.hpp" + +namespace caffe { + +/** + * @brief Reads data from a source to queues available to data layers. + * A single reading thread is created per source, even if multiple solvers + * are running in parallel, e.g. for multi-GPU training. This makes sure + * databases are read sequentially, and that each solver accesses a different + * subset of the database. Data is distributed to solvers in a round-robin + * way to keep parallel training deterministic. + */ +class DataReader { + public: + explicit DataReader(const LayerParameter& param); + ~DataReader(); + + inline BlockingQueue& free() const { + return queue_pair_->free_; + } + inline BlockingQueue& full() const { + return queue_pair_->full_; + } + + protected: + // Queue pairs are shared between a body and its readers + class QueuePair { + public: + explicit QueuePair(int size); + ~QueuePair(); + + BlockingQueue free_; + BlockingQueue full_; + + DISABLE_COPY_AND_ASSIGN(QueuePair); + }; + + // A single body is created per source + class Body : public InternalThread { + public: + explicit Body(const LayerParameter& param); + virtual ~Body(); + + protected: + void InternalThreadEntry(); + void read_one(db::Cursor* cursor, QueuePair* qp); + + const LayerParameter param_; + BlockingQueue > new_queue_pairs_; + + friend class DataReader; + + DISABLE_COPY_AND_ASSIGN(Body); + }; + + // A source is uniquely identified by its layer name + path, in case + // the same database is read from two different locations in the net. + static inline string source_key(const LayerParameter& param) { + return param.name() + ":" + param.data_param().source(); + } + + const shared_ptr queue_pair_; + shared_ptr body_; + + static map > bodies_; + +DISABLE_COPY_AND_ASSIGN(DataReader); +}; + +} // namespace caffe + +#endif // CAFFE_DATA_READER_HPP_ diff --git a/include/caffe/internal_thread.hpp b/include/caffe/internal_thread.hpp index 815ca54605e..6a8c5a02892 100644 --- a/include/caffe/internal_thread.hpp +++ b/include/caffe/internal_thread.hpp @@ -14,18 +14,22 @@ namespace caffe { /** * Virtual class encapsulate boost::thread for use in base class * The child class will acquire the ability to run a single thread, - * by reimplementing the virutal function InternalThreadEntry. + * by reimplementing the virtual function InternalThreadEntry. */ class InternalThread { public: InternalThread() : thread_() {} virtual ~InternalThread(); - /** Returns true if the thread was successfully started. **/ - bool StartInternalThread(); + /** + * Caffe's thread local state will be initialized using the current + * thread values, e.g. device id, solver index etc. The random seed + * is initialized using caffe_rng_rand. + */ + void StartInternalThread(); /** Will not return until the internal thread has exited. */ - bool WaitForInternalThreadToExit(); + void StopInternalThread(); bool is_started() const; @@ -34,6 +38,13 @@ class InternalThread { with the code you want your thread to run. */ virtual void InternalThreadEntry() {} + /* Should be tested when running loops to exit when requested. */ + bool must_stop(); + + private: + void entry(int device, Caffe::Brew mode, int rand_seed, int solver_count, + bool root_solver); + shared_ptr thread_; }; diff --git a/include/caffe/layer.hpp b/include/caffe/layer.hpp index 0771b6a8fb4..a0d1d4ecc94 100644 --- a/include/caffe/layer.hpp +++ b/include/caffe/layer.hpp @@ -11,6 +11,12 @@ #include "caffe/proto/caffe.pb.h" #include "caffe/util/device_alternate.hpp" +/** + Forward declare boost::thread instead of including boost/thread.hpp + to avoid a boost/NVCC issues (#1009, #1010) on OSX. + */ +namespace boost { class mutex; } + namespace caffe { /** @@ -32,7 +38,7 @@ class Layer { * layer. */ explicit Layer(const LayerParameter& param) - : layer_param_(param) { + : layer_param_(param), is_shared_(false) { // Set phase and copy blobs (if there are any). phase_ = param.phase(); if (layer_param_.blobs_size() > 0) { @@ -60,6 +66,7 @@ class Layer { */ void SetUp(const vector*>& bottom, const vector*>& top) { + InitMutex(); CheckBlobCounts(bottom, top); LayerSetUp(bottom, top); Reshape(bottom, top); @@ -85,6 +92,30 @@ class Layer { virtual void LayerSetUp(const vector*>& bottom, const vector*>& top) {} + /** + * @brief Whether a layer should be shared by multiple nets during data + * parallelism. By default, all layers except for data layers should + * not be shared. data layers should be shared to ensure each worker + * solver access data sequentially during data parallelism. + */ + virtual inline bool ShareInParallel() const { return false; } + + /** @brief Return whether this layer is actually shared by other nets. + * If ShareInParallel() is true and using more than one GPU and the + * net has TRAIN phase, then this function is expected return true. + */ + inline bool IsShared() const { return is_shared_; } + + /** @brief Set whether this layer is actually shared by other nets + * If ShareInParallel() is true and using more than one GPU and the + * net has TRAIN phase, then is_shared should be set true. + */ + inline void SetShared(bool is_shared) { + CHECK(ShareInParallel() || !is_shared) + << type() << "Layer does not support sharing."; + is_shared_ = is_shared; + } + /** * @brief Adjust the shapes of top blobs and internal buffers to accommodate * the shapes of the bottom blobs. @@ -396,6 +427,20 @@ class Layer { } } + private: + /** Whether this layer is actually shared by other nets*/ + bool is_shared_; + + /** The mutex for sequential forward if this layer is shared */ + shared_ptr forward_mutex_; + + /** Initialize forward_mutex_ */ + void InitMutex(); + /** Lock forward_mutex_ if this layer is shared */ + void Lock(); + /** Unlock forward_mutex_ if this layer is shared */ + void Unlock(); + DISABLE_COPY_AND_ASSIGN(Layer); }; // class Layer @@ -405,6 +450,8 @@ class Layer { template inline Dtype Layer::Forward(const vector*>& bottom, const vector*>& top) { + // Lock during forward to ensure sequential forward + Lock(); Dtype loss = 0; Reshape(bottom, top); switch (Caffe::mode()) { @@ -435,6 +482,7 @@ inline Dtype Layer::Forward(const vector*>& bottom, default: LOG(FATAL) << "Unknown caffe mode."; } + Unlock(); return loss; } diff --git a/include/caffe/layer_factory.hpp b/include/caffe/layer_factory.hpp index 2fcd93869a0..32e849de0d2 100644 --- a/include/caffe/layer_factory.hpp +++ b/include/caffe/layer_factory.hpp @@ -71,7 +71,9 @@ class LayerRegistry { // Get a layer using a LayerParameter. static shared_ptr > CreateLayer(const LayerParameter& param) { - LOG(INFO) << "Creating layer " << param.name(); + if (Caffe::root_solver()) { + LOG(INFO) << "Creating layer " << param.name(); + } const string& type = param.type(); CreatorRegistry& registry = Registry(); CHECK_EQ(registry.count(type), 1) << "Unknown layer type: " << type diff --git a/include/caffe/net.hpp b/include/caffe/net.hpp index bf997553ee2..1bf07d28d13 100644 --- a/include/caffe/net.hpp +++ b/include/caffe/net.hpp @@ -23,8 +23,9 @@ namespace caffe { template class Net { public: - explicit Net(const NetParameter& param); - explicit Net(const string& param_file, Phase phase); + explicit Net(const NetParameter& param, const Net* root_net = NULL); + explicit Net(const string& param_file, Phase phase, + const Net* root_net = NULL); virtual ~Net() {} /// @brief Initialize a network with a NetParameter. @@ -291,7 +292,8 @@ class Net { size_t memory_used_; /// Whether to compute and display debug info for the net. bool debug_info_; - + /// The root net that actually holds the shared layers in data parallelism + const Net* const root_net_; DISABLE_COPY_AND_ASSIGN(Net); }; diff --git a/include/caffe/parallel.hpp b/include/caffe/parallel.hpp new file mode 100644 index 00000000000..85fc2b55984 --- /dev/null +++ b/include/caffe/parallel.hpp @@ -0,0 +1,118 @@ +#ifndef CAFFE_PARALLEL_HPP_ +#define CAFFE_PARALLEL_HPP_ + +#include + +#include + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/internal_thread.hpp" +#include "caffe/layer.hpp" +#include "caffe/proto/caffe.pb.h" +#include "caffe/solver.hpp" +#include "caffe/syncedmem.hpp" +#include "caffe/util/blocking_queue.hpp" + +namespace caffe { + +// Represents a net parameters. Once a net is created, its parameter buffers can +// be replaced by ones from Params, to allow parallelization. Params ensures +// parameters are allocated in one consecutive array. +template +class Params { + public: + explicit Params(shared_ptr > root_solver); + virtual ~Params() { + } + + inline size_t size() const { + return size_; + } + inline Dtype* data() const { + return data_; + } + inline Dtype* diff() const { + return diff_; + } + + protected: + const size_t size_; // Size of buffers + Dtype* data_; // Network parameters + Dtype* diff_; // Gradient + +DISABLE_COPY_AND_ASSIGN(Params); +}; + +// Params stored in GPU memory. +template +class GPUParams : public Params { + public: + GPUParams(shared_ptr > root_solver, int device); + virtual ~GPUParams(); + + void configure(Solver* solver) const; + + protected: + using Params::size_; + using Params::data_; + using Params::diff_; +}; + +class DevicePair { + public: + DevicePair(int parent, int device) + : parent_(parent), + device_(device) { + } + inline int parent() { + return parent_; + } + inline int device() { + return device_; + } + + // Group GPUs in pairs, by proximity depending on machine's topology + static void compute(const vector devices, vector* pairs); + + protected: + int parent_; + int device_; +}; + +// Synchronous data parallelism using map-reduce between local GPUs. +template +class P2PSync : public GPUParams, public Solver::Callback, + public InternalThread { + public: + explicit P2PSync(shared_ptr > root_solver, + P2PSync* parent, const SolverParameter& param); + virtual ~P2PSync(); + + inline const shared_ptr >& solver() const { + return solver_; + } + + void run(const vector& gpus); + + protected: + void on_start(); + void on_gradients_ready(); + + void InternalThreadEntry(); + + P2PSync* parent_; + vector*> children_; + BlockingQueue*> queue_; + const int initial_iter_; + Dtype* parent_grads_; + shared_ptr > solver_; + + using Params::size_; + using Params::data_; + using Params::diff_; +}; + +} // namespace caffe + +#endif diff --git a/include/caffe/python_layer.hpp b/include/caffe/python_layer.hpp index 2957e7426be..c43c1e8a91b 100644 --- a/include/caffe/python_layer.hpp +++ b/include/caffe/python_layer.hpp @@ -27,6 +27,10 @@ class PythonLayer : public Layer { self_.attr("reshape")(bottom, top); } + virtual inline bool ShareInParallel() const { + return this->layer_param_.python_param().share_in_parallel(); + } + virtual inline const char* type() const { return "Python"; } protected: diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp index fbade9389ff..f583324acef 100644 --- a/include/caffe/solver.hpp +++ b/include/caffe/solver.hpp @@ -17,8 +17,9 @@ namespace caffe { template class Solver { public: - explicit Solver(const SolverParameter& param); - explicit Solver(const string& param_file); + explicit Solver(const SolverParameter& param, + const Solver* root_solver = NULL); + explicit Solver(const string& param_file, const Solver* root_solver = NULL); void Init(const SolverParameter& param); void InitTrainNet(); void InitTestNets(); @@ -32,12 +33,27 @@ class Solver { // methods to restore the state from the appropriate snapshot type. void Restore(const char* resume_file); virtual ~Solver() {} + inline const SolverParameter& param() const { return param_; } inline shared_ptr > net() { return net_; } inline const vector > >& test_nets() { return test_nets_; } int iter() { return iter_; } + // Invoked at specific points during an iteration + class Callback { + protected: + virtual void on_start() = 0; + virtual void on_gradients_ready() = 0; + + template + friend class Solver; + }; + const vector& callbacks() const { return callbacks_; } + void add_callback(Callback* value) { + callbacks_.push_back(value); + } + protected: // Make and apply the update value for the current iteration. virtual void ApplyUpdate() = 0; @@ -62,10 +78,38 @@ class Solver { int current_step_; shared_ptr > net_; vector > > test_nets_; + vector callbacks_; + + // The root solver that holds root nets (actually containing shared layers) + // in data parallelism + const Solver* const root_solver_; DISABLE_COPY_AND_ASSIGN(Solver); }; +/** + * @brief Solver that only computes gradients, used as worker + * for multi-GPU training. + */ +template +class WorkerSolver : public Solver { + public: + explicit WorkerSolver(const SolverParameter& param, + const Solver* root_solver = NULL) + : Solver(param, root_solver) {} + + protected: + void ApplyUpdate() {} + void SnapshotSolverState(const string& model_filename) { + LOG(FATAL) << "Should not be called on worker solver."; + } + void RestoreSolverStateFromBinaryProto(const string& state_file) { + LOG(FATAL) << "Should not be called on worker solver."; + } + void RestoreSolverStateFromHDF5(const string& state_file) { + LOG(FATAL) << "Should not be called on worker solver."; + } +}; /** * @brief Optimizes the parameters of a Net using diff --git a/include/caffe/syncedmem.hpp b/include/caffe/syncedmem.hpp index 1b726de9564..62aadef498d 100644 --- a/include/caffe/syncedmem.hpp +++ b/include/caffe/syncedmem.hpp @@ -8,26 +8,29 @@ namespace caffe { -// Theoretically, CaffeMallocHost and CaffeFreeHost should simply call the -// cudaMallocHost and cudaFree functions in order to create pinned memory. -// However, those codes rely on the existence of a cuda GPU (I don't know -// why that is a must since allocating memory should not be accessing the -// GPU resource, but it just creates an error as of Cuda 5.0) and will cause -// problem when running on a machine without GPU. Thus, we simply define -// these two functions for safety and possible future change if the problem -// of calling cuda functions disappears in a future version. -// -// In practice, although we are creating unpinned memory here, as long as we -// are constantly accessing them the memory pages almost always stays in -// the physical memory (assuming we have large enough memory installed), and -// does not seem to create a memory bottleneck here. - +// If CUDA is available and in GPU mode, host memory will be allocated pinned, +// using cudaMallocHost. It avoids dynamic pinning for transfers (DMA). +// The improvement in performance seems negligible in the single GPU case, +// but might be more significant for parallel training. Most importantly, +// it improved stability for large models on many GPUs. inline void CaffeMallocHost(void** ptr, size_t size) { +#ifndef CPU_ONLY + if (Caffe::mode() == Caffe::GPU) { + CUDA_CHECK(cudaMallocHost(ptr, size)); + return; + } +#endif *ptr = malloc(size); CHECK(*ptr) << "host allocation of size " << size << " failed"; } inline void CaffeFreeHost(void* ptr) { +#ifndef CPU_ONLY + if (Caffe::mode() == Caffe::GPU) { + CUDA_CHECK(cudaFreeHost(ptr)); + return; + } +#endif free(ptr); } @@ -42,20 +45,25 @@ class SyncedMemory { public: SyncedMemory() : cpu_ptr_(NULL), gpu_ptr_(NULL), size_(0), head_(UNINITIALIZED), - own_cpu_data_(false) {} + own_cpu_data_(false), own_gpu_data_(false), gpu_device_(-1) {} explicit SyncedMemory(size_t size) : cpu_ptr_(NULL), gpu_ptr_(NULL), size_(size), head_(UNINITIALIZED), - own_cpu_data_(false) {} + own_cpu_data_(false), own_gpu_data_(false), gpu_device_(-1) {} ~SyncedMemory(); const void* cpu_data(); void set_cpu_data(void* data); const void* gpu_data(); + void set_gpu_data(void* data); void* mutable_cpu_data(); void* mutable_gpu_data(); enum SyncedHead { UNINITIALIZED, HEAD_AT_CPU, HEAD_AT_GPU, SYNCED }; SyncedHead head() { return head_; } size_t size() { return size_; } +#ifndef CPU_ONLY + void async_gpu_push(const cudaStream_t& stream); +#endif + private: void to_cpu(); void to_gpu(); @@ -64,6 +72,8 @@ class SyncedMemory { size_t size_; SyncedHead head_; bool own_cpu_data_; + bool own_gpu_data_; + int gpu_device_; DISABLE_COPY_AND_ASSIGN(SyncedMemory); }; // class SyncedMemory diff --git a/include/caffe/util/blocking_queue.hpp b/include/caffe/util/blocking_queue.hpp new file mode 100644 index 00000000000..955e12cc567 --- /dev/null +++ b/include/caffe/util/blocking_queue.hpp @@ -0,0 +1,47 @@ +#ifndef CAFFE_UTIL_BLOCKING_QUEUE_HPP_ +#define CAFFE_UTIL_BLOCKING_QUEUE_HPP_ + +#include +#include + +#include "caffe/common.hpp" + +namespace caffe { + +template +class BlockingQueue { + public: + explicit BlockingQueue(); + + void push(const T& t); + + bool try_pop(T* t); + + // This logs a message if the threads needs to be blocked + // useful for detecting e.g. when data feeding is too slow + T pop(const string& log_on_wait = ""); + + bool try_peek(T* t); + + // Return element without removing it + T peek(); + + size_t size() const; + + protected: + /** + Move synchronization fields out instead of including boost/thread.hpp + to avoid a boost/NVCC issues (#1009, #1010) on OSX. Also fails on + Linux CUDA 7.0.18. + */ + class sync; + + std::queue queue_; + shared_ptr sync_; + +DISABLE_COPY_AND_ASSIGN(BlockingQueue); +}; + +} // namespace caffe + +#endif diff --git a/src/caffe/common.cpp b/src/caffe/common.cpp index af96cac40aa..7077f3789d7 100644 --- a/src/caffe/common.cpp +++ b/src/caffe/common.cpp @@ -1,3 +1,4 @@ +#include #include #include #include @@ -7,7 +8,15 @@ namespace caffe { -shared_ptr Caffe::singleton_; +// Make sure each thread can have different values. +static boost::thread_specific_ptr thread_instance_; + +Caffe& Caffe::Get() { + if (!thread_instance_.get()) { + thread_instance_.reset(new Caffe()); + } + return *(thread_instance_.get()); +} // random seeding int64_t cluster_seedgen(void) { @@ -42,7 +51,8 @@ void GlobalInit(int* pargc, char*** pargv) { #ifdef CPU_ONLY // CPU-only Caffe. Caffe::Caffe() - : random_generator_(), mode_(Caffe::CPU) { } + : random_generator_(), mode_(Caffe::CPU), + solver_count_(1), root_solver_(true) { } Caffe::~Caffe() { } @@ -86,7 +96,7 @@ void* Caffe::RNG::generator() { Caffe::Caffe() : cublas_handle_(NULL), curand_generator_(NULL), random_generator_(), - mode_(Caffe::CPU) { + mode_(Caffe::CPU), solver_count_(1), root_solver_(true) { // Try to create a cublas handler, and report an error if failed (but we will // keep the program running as one might just want to run CPU code). if (cublasCreate(&cublas_handle_) != CUBLAS_STATUS_SUCCESS) { diff --git a/src/caffe/data_reader.cpp b/src/caffe/data_reader.cpp new file mode 100644 index 00000000000..16378203a88 --- /dev/null +++ b/src/caffe/data_reader.cpp @@ -0,0 +1,119 @@ +#include +#include +#include +#include + +#include "caffe/common.hpp" +#include "caffe/data_layers.hpp" +#include "caffe/data_reader.hpp" +#include "caffe/proto/caffe.pb.h" + +namespace caffe { + +using boost::weak_ptr; + +map > DataReader::bodies_; +static boost::mutex bodies_mutex_; + +DataReader::DataReader(const LayerParameter& param) + : queue_pair_(new QueuePair( // + param.data_param().prefetch() * param.data_param().batch_size())) { + // Get or create a body + boost::mutex::scoped_lock lock(bodies_mutex_); + string key = source_key(param); + weak_ptr& weak = bodies_[key]; + body_ = weak.lock(); + if (!body_) { + body_.reset(new Body(param)); + bodies_[key] = weak_ptr(body_); + } + body_->new_queue_pairs_.push(queue_pair_); +} + +DataReader::~DataReader() { + string key = source_key(body_->param_); + body_.reset(); + boost::mutex::scoped_lock lock(bodies_mutex_); + if (bodies_[key].expired()) { + bodies_.erase(key); + } +} + +// + +DataReader::QueuePair::QueuePair(int size) { + // Initialize the free queue with requested number of datums + for (int i = 0; i < size; ++i) { + free_.push(new Datum()); + } +} + +DataReader::QueuePair::~QueuePair() { + Datum* datum; + while (free_.try_pop(&datum)) { + delete datum; + } + while (full_.try_pop(&datum)) { + delete datum; + } +} + +// + +DataReader::Body::Body(const LayerParameter& param) + : param_(param), + new_queue_pairs_() { + StartInternalThread(); +} + +DataReader::Body::~Body() { + StopInternalThread(); +} + +void DataReader::Body::InternalThreadEntry() { + shared_ptr db(db::GetDB(param_.data_param().backend())); + db->Open(param_.data_param().source(), db::READ); + shared_ptr cursor(db->NewCursor()); + vector > qps; + try { + int solver_count = param_.phase() == TRAIN ? Caffe::solver_count() : 1; + + // To ensure deterministic runs, only start running once all solvers + // are ready. But solvers need to peek on one item during initialization, + // so read one item, then wait for the next solver. + for (int i = 0; i < solver_count; ++i) { + shared_ptr qp(new_queue_pairs_.pop()); + read_one(cursor.get(), qp.get()); + qps.push_back(qp); + } + // Main loop + while (!must_stop()) { + for (int i = 0; i < solver_count; ++i) { + read_one(cursor.get(), qps[i].get()); + } + // Check no additional readers have been created. This can happen if + // more than one net is trained at a time per process, whether single + // or multi solver. It might also happen if two data layers have same + // name and same source. + CHECK_EQ(new_queue_pairs_.size(), 0); + } + } catch (boost::thread_interrupted&) { + // Interrupted exception is expected on shutdown + } +} + +void DataReader::Body::read_one(db::Cursor* cursor, QueuePair* qp) { + Datum* datum = qp->free_.pop(); + // TODO deserialize in-place instead of copy? + datum->ParseFromString(cursor->value()); + qp->full_.push(datum); + + // go to the next iter + cursor->Next(); + if (!cursor->valid()) { + DLOG(INFO) << "Restarting data prefetching from start."; + cursor->SeekToFirst(); + } +} + +} // namespace caffe diff --git a/src/caffe/data_transformer.cpp b/src/caffe/data_transformer.cpp index 22633922a01..4666d9bd881 100644 --- a/src/caffe/data_transformer.cpp +++ b/src/caffe/data_transformer.cpp @@ -19,7 +19,9 @@ DataTransformer::DataTransformer(const TransformationParameter& param, CHECK_EQ(param_.mean_value_size(), 0) << "Cannot specify mean_file and mean_value at the same time"; const string& mean_file = param.mean_file(); - LOG(INFO) << "Loading mean file from: " << mean_file; + if (Caffe::root_solver()) { + LOG(INFO) << "Loading mean file from: " << mean_file; + } BlobProto blob_proto; ReadProtoFromBinaryFileOrDie(mean_file.c_str(), &blob_proto); data_mean_.FromProto(blob_proto); diff --git a/src/caffe/internal_thread.cpp b/src/caffe/internal_thread.cpp index c2d19d433b4..104884e0295 100644 --- a/src/caffe/internal_thread.cpp +++ b/src/caffe/internal_thread.cpp @@ -1,40 +1,66 @@ #include +#include + #include "caffe/internal_thread.hpp" +#include "caffe/util/math_functions.hpp" namespace caffe { InternalThread::~InternalThread() { - WaitForInternalThreadToExit(); + StopInternalThread(); } bool InternalThread::is_started() const { - return thread_.get() != NULL && thread_->joinable(); + return thread_ && thread_->joinable(); +} + +bool InternalThread::must_stop() { + return thread_ && thread_->interruption_requested(); } +void InternalThread::StartInternalThread() { + CHECK(!is_started()) << "Threads should persist and not be restarted."; + + int device = 0; +#ifndef CPU_ONLY + CUDA_CHECK(cudaGetDevice(&device)); +#endif + Caffe::Brew mode = Caffe::mode(); + int rand_seed = caffe_rng_rand(); + int solver_count = Caffe::solver_count(); + bool root_solver = Caffe::root_solver(); -bool InternalThread::StartInternalThread() { - if (!WaitForInternalThreadToExit()) { - return false; - } try { - thread_.reset( - new boost::thread(&InternalThread::InternalThreadEntry, this)); - } catch (...) { - return false; + thread_.reset(new boost::thread(&InternalThread::entry, this, device, mode, + rand_seed, solver_count, root_solver)); + } catch (std::exception& e) { + LOG(FATAL) << "Thread exception: " << e.what(); } - return true; } -/** Will not return until the internal thread has exited. */ -bool InternalThread::WaitForInternalThreadToExit() { +void InternalThread::entry(int device, Caffe::Brew mode, int rand_seed, + int solver_count, bool root_solver) { +#ifndef CPU_ONLY + CUDA_CHECK(cudaSetDevice(device)); +#endif + Caffe::set_mode(mode); + Caffe::set_random_seed(rand_seed); + Caffe::set_solver_count(solver_count); + Caffe::set_root_solver(root_solver); + + InternalThreadEntry(); +} + +void InternalThread::StopInternalThread() { if (is_started()) { + thread_->interrupt(); try { thread_->join(); - } catch (...) { - return false; + } catch (boost::thread_interrupted&) { + } catch (std::exception& e) { + LOG(FATAL) << "Thread exception: " << e.what(); } } - return true; } } // namespace caffe diff --git a/src/caffe/layer.cpp b/src/caffe/layer.cpp new file mode 100644 index 00000000000..3b9128986ae --- /dev/null +++ b/src/caffe/layer.cpp @@ -0,0 +1,27 @@ +#include +#include "caffe/layer.hpp" + +namespace caffe { + +template +void Layer::InitMutex() { + forward_mutex_.reset(new boost::mutex()); +} + +template +void Layer::Lock() { + if (IsShared()) { + forward_mutex_->lock(); + } +} + +template +void Layer::Unlock() { + if (IsShared()) { + forward_mutex_->unlock(); + } +} + +INSTANTIATE_CLASS(Layer); + +} // namespace caffe diff --git a/src/caffe/layers/base_data_layer.cpp b/src/caffe/layers/base_data_layer.cpp index 26a1118282f..20f76f62994 100644 --- a/src/caffe/layers/base_data_layer.cpp +++ b/src/caffe/layers/base_data_layer.cpp @@ -1,7 +1,9 @@ +#include #include #include #include "caffe/data_layers.hpp" +#include "caffe/net.hpp" #include "caffe/util/io.hpp" namespace caffe { @@ -27,56 +29,92 @@ void BaseDataLayer::LayerSetUp(const vector*>& bottom, DataLayerSetUp(bottom, top); } +template +BasePrefetchingDataLayer::BasePrefetchingDataLayer( + const LayerParameter& param) + : BaseDataLayer(param), + prefetch_free_(), prefetch_full_() { + for (int i = 0; i < PREFETCH_COUNT; ++i) { + prefetch_free_.push(&prefetch_[i]); + } +} + template void BasePrefetchingDataLayer::LayerSetUp( const vector*>& bottom, const vector*>& top) { BaseDataLayer::LayerSetUp(bottom, top); - // Now, start the prefetch thread. Before calling prefetch, we make two - // cpu_data calls so that the prefetch thread does not accidentally make - // simultaneous cudaMalloc calls when the main thread is running. In some - // GPUs this seems to cause failures if we do not so. - this->prefetch_data_.mutable_cpu_data(); - if (this->output_labels_) { - this->prefetch_label_.mutable_cpu_data(); + // Before starting the prefetch thread, we make cpu_data and gpu_data + // calls so that the prefetch thread does not accidentally make simultaneous + // cudaMalloc calls when the main thread is running. In some GPUs this + // seems to cause failures if we do not so. + for (int i = 0; i < PREFETCH_COUNT; ++i) { + prefetch_[i].data_.mutable_cpu_data(); + if (this->output_labels_) { + prefetch_[i].label_.mutable_cpu_data(); + } + } +#ifndef CPU_ONLY + if (Caffe::mode() == Caffe::GPU) { + for (int i = 0; i < PREFETCH_COUNT; ++i) { + prefetch_[i].data_.mutable_gpu_data(); + if (this->output_labels_) { + prefetch_[i].label_.mutable_gpu_data(); + } + } } +#endif DLOG(INFO) << "Initializing prefetch"; - this->CreatePrefetchThread(); + this->data_transformer_->InitRand(); + StartInternalThread(); DLOG(INFO) << "Prefetch initialized."; } template -void BasePrefetchingDataLayer::CreatePrefetchThread() { - this->data_transformer_->InitRand(); - CHECK(StartInternalThread()) << "Thread execution failed"; -} +void BasePrefetchingDataLayer::InternalThreadEntry() { +#ifndef CPU_ONLY + cudaStream_t stream; + if (Caffe::mode() == Caffe::GPU) { + cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking); + } +#endif -template -void BasePrefetchingDataLayer::JoinPrefetchThread() { - CHECK(WaitForInternalThreadToExit()) << "Thread joining failed"; + try { + while (!must_stop()) { + Batch* batch = prefetch_free_.pop(); + load_batch(batch); +#ifndef CPU_ONLY + if (Caffe::mode() == Caffe::GPU) { + batch->data_.data().get()->async_gpu_push(stream); + cudaStreamSynchronize(stream); + } +#endif + prefetch_full_.push(batch); + } + } catch (boost::thread_interrupted&) { + // Interrupted exception is expected on shutdown + } } template void BasePrefetchingDataLayer::Forward_cpu( const vector*>& bottom, const vector*>& top) { - // First, join the thread - JoinPrefetchThread(); - DLOG(INFO) << "Thread joined"; + Batch* batch = prefetch_full_.pop("Data layer prefetch queue empty"); // Reshape to loaded data. - top[0]->ReshapeLike(prefetch_data_); + top[0]->Reshape(batch->data_.num(), batch->data_.channels(), + batch->data_.height(), batch->data_.width()); // Copy the data - caffe_copy(prefetch_data_.count(), prefetch_data_.cpu_data(), + caffe_copy(batch->data_.count(), batch->data_.cpu_data(), top[0]->mutable_cpu_data()); DLOG(INFO) << "Prefetch copied"; if (this->output_labels_) { // Reshape to loaded labels. - top[1]->ReshapeLike(prefetch_label_); + top[1]->ReshapeLike(batch->label_); // Copy the labels. - caffe_copy(prefetch_label_.count(), prefetch_label_.cpu_data(), - top[1]->mutable_cpu_data()); + caffe_copy(batch->label_.count(), batch->label_.cpu_data(), + top[1]->mutable_cpu_data()); } - // Start a new prefetch thread - DLOG(INFO) << "CreatePrefetchThread"; - CreatePrefetchThread(); + + prefetch_free_.push(batch); } #ifdef CPU_ONLY diff --git a/src/caffe/layers/base_data_layer.cu b/src/caffe/layers/base_data_layer.cu index 9335a5bc9a9..56439bc506a 100644 --- a/src/caffe/layers/base_data_layer.cu +++ b/src/caffe/layers/base_data_layer.cu @@ -7,22 +7,21 @@ namespace caffe { template void BasePrefetchingDataLayer::Forward_gpu( const vector*>& bottom, const vector*>& top) { - // First, join the thread - JoinPrefetchThread(); + Batch* batch = prefetch_full_.pop("Data layer prefetch queue empty"); // Reshape to loaded data. - top[0]->ReshapeLike(this->prefetch_data_); + top[0]->ReshapeLike(batch->data_); // Copy the data - caffe_copy(prefetch_data_.count(), prefetch_data_.cpu_data(), + caffe_copy(batch->data_.count(), batch->data_.gpu_data(), top[0]->mutable_gpu_data()); if (this->output_labels_) { // Reshape to loaded labels. - top[1]->ReshapeLike(prefetch_label_); + top[1]->ReshapeLike(batch->label_); // Copy the labels. - caffe_copy(prefetch_label_.count(), prefetch_label_.cpu_data(), + caffe_copy(batch->label_.count(), batch->label_.gpu_data(), top[1]->mutable_gpu_data()); } - // Start a new prefetch thread - CreatePrefetchThread(); + + prefetch_free_.push(batch); } INSTANTIATE_LAYER_GPU_FORWARD(BasePrefetchingDataLayer); diff --git a/src/caffe/layers/data_layer.cpp b/src/caffe/layers/data_layer.cpp index 161a75e0c8c..0932d9feff3 100644 --- a/src/caffe/layers/data_layer.cpp +++ b/src/caffe/layers/data_layer.cpp @@ -11,93 +11,85 @@ #include "caffe/proto/caffe.pb.h" #include "caffe/util/benchmark.hpp" #include "caffe/util/io.hpp" -#include "caffe/util/math_functions.hpp" -#include "caffe/util/rng.hpp" namespace caffe { template -DataLayer::~DataLayer() { - this->JoinPrefetchThread(); +DataLayer::DataLayer(const LayerParameter& param) + : BasePrefetchingDataLayer(param), + reader_(param) { +} + +template +DataLayer::~DataLayer() { + this->StopInternalThread(); } template void DataLayer::DataLayerSetUp(const vector*>& bottom, const vector*>& top) { - // Initialize DB - db_.reset(db::GetDB(this->layer_param_.data_param().backend())); - db_->Open(this->layer_param_.data_param().source(), db::READ); - cursor_.reset(db_->NewCursor()); + const int batch_size = this->layer_param_.data_param().batch_size(); + // Read a data point, and use it to initialize the top blob. + Datum& datum = *(reader_.full().peek()); - // Check if we should randomly skip a few data points - if (this->layer_param_.data_param().rand_skip()) { - unsigned int skip = caffe_rng_rand() % - this->layer_param_.data_param().rand_skip(); - LOG(INFO) << "Skipping first " << skip << " data points."; - while (skip-- > 0) { - cursor_->Next(); - } - } - // Read a data point, to initialize the prefetch and top blobs. - Datum datum; - datum.ParseFromString(cursor_->value()); // Use data_transformer to infer the expected blob shape from datum. vector top_shape = this->data_transformer_->InferBlobShape(datum); this->transformed_data_.Reshape(top_shape); // Reshape top[0] and prefetch_data according to the batch_size. - top_shape[0] = this->layer_param_.data_param().batch_size(); - this->prefetch_data_.Reshape(top_shape); - top[0]->ReshapeLike(this->prefetch_data_); - + top_shape[0] = batch_size; + top[0]->Reshape(top_shape); + for (int i = 0; i < this->PREFETCH_COUNT; ++i) { + this->prefetch_[i].data_.Reshape(top_shape); + } LOG(INFO) << "output data size: " << top[0]->num() << "," << top[0]->channels() << "," << top[0]->height() << "," << top[0]->width(); // label if (this->output_labels_) { - vector label_shape(1, this->layer_param_.data_param().batch_size()); + vector label_shape(1, batch_size); top[1]->Reshape(label_shape); - this->prefetch_label_.Reshape(label_shape); + for (int i = 0; i < this->PREFETCH_COUNT; ++i) { + this->prefetch_[i].label_.Reshape(label_shape); + } } } -// This function is used to create a thread that prefetches the data. -template -void DataLayer::InternalThreadEntry() { +// This function is called on prefetch thread +template +void DataLayer::load_batch(Batch* batch) { CPUTimer batch_timer; batch_timer.Start(); double read_time = 0; double trans_time = 0; CPUTimer timer; - CHECK(this->prefetch_data_.count()); + CHECK(batch->data_.count()); CHECK(this->transformed_data_.count()); // Reshape according to the first datum of each batch // on single input batches allows for inputs of varying dimension. const int batch_size = this->layer_param_.data_param().batch_size(); - Datum datum; - datum.ParseFromString(cursor_->value()); + Datum& datum = *(reader_.full().peek()); // Use data_transformer to infer the expected blob shape from datum. vector top_shape = this->data_transformer_->InferBlobShape(datum); this->transformed_data_.Reshape(top_shape); - // Reshape prefetch_data according to the batch_size. + // Reshape batch according to the batch_size. top_shape[0] = batch_size; - this->prefetch_data_.Reshape(top_shape); + batch->data_.Reshape(top_shape); - Dtype* top_data = this->prefetch_data_.mutable_cpu_data(); + Dtype* top_data = batch->data_.mutable_cpu_data(); Dtype* top_label = NULL; // suppress warnings about uninitialized variables if (this->output_labels_) { - top_label = this->prefetch_label_.mutable_cpu_data(); + top_label = batch->label_.mutable_cpu_data(); } - timer.Start(); for (int item_id = 0; item_id < batch_size; ++item_id) { + timer.Start(); // get a datum - Datum datum; - datum.ParseFromString(cursor_->value()); + Datum& datum = *(reader_.full().pop("Waiting for data")); read_time += timer.MicroSeconds(); timer.Start(); // Apply data transformations (mirror, scale, crop...) - int offset = this->prefetch_data_.offset(item_id); + int offset = batch->data_.offset(item_id); this->transformed_data_.set_cpu_data(top_data + offset); this->data_transformer_->Transform(datum, &(this->transformed_data_)); // Copy label. @@ -105,13 +97,8 @@ void DataLayer::InternalThreadEntry() { top_label[item_id] = datum.label(); } trans_time += timer.MicroSeconds(); - timer.Start(); - // go to the next item. - cursor_->Next(); - if (!cursor_->valid()) { - DLOG(INFO) << "Restarting data prefetching from start."; - cursor_->SeekToFirst(); - } + + reader_.free().push(const_cast(&datum)); } timer.Stop(); batch_timer.Stop(); diff --git a/src/caffe/layers/image_data_layer.cpp b/src/caffe/layers/image_data_layer.cpp index dcc53348304..223ba3a75ca 100644 --- a/src/caffe/layers/image_data_layer.cpp +++ b/src/caffe/layers/image_data_layer.cpp @@ -17,7 +17,7 @@ namespace caffe { template ImageDataLayer::~ImageDataLayer() { - this->JoinPrefetchThread(); + this->StopInternalThread(); } template @@ -70,8 +70,10 @@ void ImageDataLayer::DataLayerSetUp(const vector*>& bottom, const int batch_size = this->layer_param_.image_data_param().batch_size(); CHECK_GT(batch_size, 0) << "Positive batch size required"; top_shape[0] = batch_size; - this->prefetch_data_.Reshape(top_shape); - top[0]->ReshapeLike(this->prefetch_data_); + for (int i = 0; i < this->PREFETCH_COUNT; ++i) { + this->prefetch_[i].data_.Reshape(top_shape); + } + top[0]->Reshape(top_shape); LOG(INFO) << "output data size: " << top[0]->num() << "," << top[0]->channels() << "," << top[0]->height() << "," @@ -79,7 +81,9 @@ void ImageDataLayer::DataLayerSetUp(const vector*>& bottom, // label vector label_shape(1, batch_size); top[1]->Reshape(label_shape); - this->prefetch_label_.Reshape(label_shape); + for (int i = 0; i < this->PREFETCH_COUNT; ++i) { + this->prefetch_[i].label_.Reshape(label_shape); + } } template @@ -89,15 +93,15 @@ void ImageDataLayer::ShuffleImages() { shuffle(lines_.begin(), lines_.end(), prefetch_rng); } -// This function is used to create a thread that prefetches the data. +// This function is called on prefetch thread template -void ImageDataLayer::InternalThreadEntry() { +void ImageDataLayer::load_batch(Batch* batch) { CPUTimer batch_timer; batch_timer.Start(); double read_time = 0; double trans_time = 0; CPUTimer timer; - CHECK(this->prefetch_data_.count()); + CHECK(batch->data_.count()); CHECK(this->transformed_data_.count()); ImageDataParameter image_data_param = this->layer_param_.image_data_param(); const int batch_size = image_data_param.batch_size(); @@ -114,12 +118,12 @@ void ImageDataLayer::InternalThreadEntry() { // Use data_transformer to infer the expected blob shape from a cv_img. vector top_shape = this->data_transformer_->InferBlobShape(cv_img); this->transformed_data_.Reshape(top_shape); - // Reshape prefetch_data according to the batch_size. + // Reshape batch according to the batch_size. top_shape[0] = batch_size; - this->prefetch_data_.Reshape(top_shape); + batch->data_.Reshape(top_shape); - Dtype* prefetch_data = this->prefetch_data_.mutable_cpu_data(); - Dtype* prefetch_label = this->prefetch_label_.mutable_cpu_data(); + Dtype* prefetch_data = batch->data_.mutable_cpu_data(); + Dtype* prefetch_label = batch->label_.mutable_cpu_data(); // datum scales const int lines_size = lines_.size(); @@ -133,7 +137,7 @@ void ImageDataLayer::InternalThreadEntry() { read_time += timer.MicroSeconds(); timer.Start(); // Apply transformations (mirror, crop...) to the image - int offset = this->prefetch_data_.offset(item_id); + int offset = batch->data_.offset(item_id); this->transformed_data_.set_cpu_data(prefetch_data + offset); this->data_transformer_->Transform(cv_img, &(this->transformed_data_)); trans_time += timer.MicroSeconds(); diff --git a/src/caffe/layers/window_data_layer.cpp b/src/caffe/layers/window_data_layer.cpp index c127d56bc46..f637f2ec6d4 100644 --- a/src/caffe/layers/window_data_layer.cpp +++ b/src/caffe/layers/window_data_layer.cpp @@ -27,7 +27,7 @@ namespace caffe { template WindowDataLayer::~WindowDataLayer() { - this->JoinPrefetchThread(); + this->StopInternalThread(); } template @@ -171,7 +171,9 @@ void WindowDataLayer::DataLayerSetUp(const vector*>& bottom, CHECK_GT(crop_size, 0); const int batch_size = this->layer_param_.window_data_param().batch_size(); top[0]->Reshape(batch_size, channels, crop_size, crop_size); - this->prefetch_data_.Reshape(batch_size, channels, crop_size, crop_size); + for (int i = 0; i < this->PREFETCH_COUNT; ++i) + this->prefetch_[i].data_.Reshape( + batch_size, channels, crop_size, crop_size); LOG(INFO) << "output data size: " << top[0]->num() << "," << top[0]->channels() << "," << top[0]->height() << "," @@ -179,7 +181,9 @@ void WindowDataLayer::DataLayerSetUp(const vector*>& bottom, // label vector label_shape(1, batch_size); top[1]->Reshape(label_shape); - this->prefetch_label_.Reshape(label_shape); + for (int i = 0; i < this->PREFETCH_COUNT; ++i) { + this->prefetch_[i].label_.Reshape(label_shape); + } // data mean has_mean_file_ = this->transform_param_.has_mean_file(); @@ -217,9 +221,9 @@ unsigned int WindowDataLayer::PrefetchRand() { return (*prefetch_rng)(); } -// Thread fetching the data +// This function is called on prefetch thread template -void WindowDataLayer::InternalThreadEntry() { +void WindowDataLayer::load_batch(Batch* batch) { // At each iteration, sample N windows where N*p are foreground (object) // windows and N*(1-p) are background (non-object) windows CPUTimer batch_timer; @@ -227,8 +231,8 @@ void WindowDataLayer::InternalThreadEntry() { double read_time = 0; double trans_time = 0; CPUTimer timer; - Dtype* top_data = this->prefetch_data_.mutable_cpu_data(); - Dtype* top_label = this->prefetch_label_.mutable_cpu_data(); + Dtype* top_data = batch->data_.mutable_cpu_data(); + Dtype* top_label = batch->label_.mutable_cpu_data(); const Dtype scale = this->layer_param_.window_data_param().scale(); const int batch_size = this->layer_param_.window_data_param().batch_size(); const int context_pad = this->layer_param_.window_data_param().context_pad(); @@ -252,7 +256,7 @@ void WindowDataLayer::InternalThreadEntry() { bool use_square = (crop_mode == "square") ? true : false; // zero out batch - caffe_set(this->prefetch_data_.count(), Dtype(0), top_data); + caffe_set(batch->data_.count(), Dtype(0), top_data); const int num_fg = static_cast(static_cast(batch_size) * fg_fraction); diff --git a/src/caffe/net.cpp b/src/caffe/net.cpp index 0e5ed804b73..7f5bdf7e2ba 100644 --- a/src/caffe/net.cpp +++ b/src/caffe/net.cpp @@ -10,6 +10,7 @@ #include "caffe/common.hpp" #include "caffe/layer.hpp" #include "caffe/net.hpp" +#include "caffe/parallel.hpp" #include "caffe/proto/caffe.pb.h" #include "caffe/util/hdf5.hpp" #include "caffe/util/insert_splits.hpp" @@ -21,12 +22,14 @@ namespace caffe { template -Net::Net(const NetParameter& param) { +Net::Net(const NetParameter& param, const Net* root_net) + : root_net_(root_net) { Init(param); } template -Net::Net(const string& param_file, Phase phase) { +Net::Net(const string& param_file, Phase phase, const Net* root_net) + : root_net_(root_net) { NetParameter param; ReadNetParamsFromTextFileOrDie(param_file, ¶m); param.mutable_state()->set_phase(phase); @@ -35,14 +38,18 @@ Net::Net(const string& param_file, Phase phase) { template void Net::Init(const NetParameter& in_param) { + CHECK(Caffe::root_solver() || root_net_) + << "root_net_ needs to be set for all non-root solvers"; // Set phase from the state. phase_ = in_param.state().phase(); // Filter layers based on their include/exclude rules and // the current NetState. NetParameter filtered_param; FilterNet(in_param, &filtered_param); - LOG(INFO) << "Initializing net from parameters: " << std::endl - << filtered_param.DebugString(); + if (Caffe::root_solver()) { + LOG(INFO) << "Initializing net from parameters: " << std::endl + << filtered_param.DebugString(); + } // Create a copy of filtered_param with splits added where necessary. NetParameter param; InsertSplits(filtered_param, ¶m); @@ -66,7 +73,8 @@ void Net::Init(const NetParameter& in_param) { const int layer_id = -1; // inputs have fake layer ID -1 AppendTop(param, layer_id, input_id, &available_blobs, &blob_name_to_idx); } - DLOG(INFO) << "Memory required for data: " << memory_used_ * sizeof(Dtype); + DLOG_IF(INFO, Caffe::root_solver()) + << "Memory required for data: " << memory_used_ * sizeof(Dtype); // For each layer, set up its input and output bottom_vecs_.resize(param.layer_size()); top_vecs_.resize(param.layer_size()); @@ -75,6 +83,9 @@ void Net::Init(const NetParameter& in_param) { top_id_vecs_.resize(param.layer_size()); bottom_need_backward_.resize(param.layer_size()); for (int layer_id = 0; layer_id < param.layer_size(); ++layer_id) { + // For non-root solvers, whether this layer is shared from root_net_. + bool share_from_root = !Caffe::root_solver() + && root_net_->layers_[layer_id]->ShareInParallel(); // Inherit phase from net if unset. if (!param.layer(layer_id).has_phase()) { param.mutable_layer(layer_id)->set_phase(phase_); @@ -87,9 +98,17 @@ void Net::Init(const NetParameter& in_param) { << "propagate_down param must be specified " << "either 0 or bottom_size times "; } - layers_.push_back(LayerRegistry::CreateLayer(layer_param)); + if (share_from_root) { + LOG(INFO) << "Sharing layer " << layer_param.name() << " from root net"; + layers_.push_back(root_net_->layers_[layer_id]); + layers_[layer_id]->SetShared(true); + } else { + layers_.push_back(LayerRegistry::CreateLayer(layer_param)); + } layer_names_.push_back(layer_param.name()); - LOG(INFO) << "Creating Layer " << layer_param.name(); + if (Caffe::root_solver()) { + LOG(INFO) << "Creating Layer " << layer_param.name(); + } bool need_backward = false; // Figure out this layer's input and output @@ -119,20 +138,42 @@ void Net::Init(const NetParameter& in_param) { } } // After this layer is connected, set it up. - LOG(INFO) << "Setting up " << layer_names_[layer_id]; - layers_[layer_id]->SetUp(bottom_vecs_[layer_id], top_vecs_[layer_id]); + if (share_from_root) { + // Set up size of top blobs using root_net_ + const vector*>& base_top = root_net_->top_vecs_[layer_id]; + const vector*>& this_top = this->top_vecs_[layer_id]; + for (int top_id = 0; top_id < base_top.size(); ++top_id) { + this_top[top_id]->ReshapeLike(*base_top[top_id]); + LOG(INFO) << "Created top blob " << top_id << " (shape: " + << this_top[top_id]->shape_string() << ") for shared layer " + << layer_param.name(); + } + } else { + layers_[layer_id]->SetUp(bottom_vecs_[layer_id], top_vecs_[layer_id]); + } + if (Caffe::root_solver()) { + LOG(INFO) << "Setting up " << layer_names_[layer_id]; + } for (int top_id = 0; top_id < top_vecs_[layer_id].size(); ++top_id) { if (blob_loss_weights_.size() <= top_id_vecs_[layer_id][top_id]) { blob_loss_weights_.resize(top_id_vecs_[layer_id][top_id] + 1, Dtype(0)); } blob_loss_weights_[top_id_vecs_[layer_id][top_id]] = layer->loss(top_id); - LOG(INFO) << "Top shape: " << top_vecs_[layer_id][top_id]->shape_string(); + if (Caffe::root_solver()) { + LOG(INFO) << "Top shape: " + << top_vecs_[layer_id][top_id]->shape_string(); + } if (layer->loss(top_id)) { - LOG(INFO) << " with loss weight " << layer->loss(top_id); + if (Caffe::root_solver()) { + LOG(INFO) << " with loss weight " << layer->loss(top_id); + } } memory_used_ += top_vecs_[layer_id][top_id]->count(); } - DLOG(INFO) << "Memory required for data: " << memory_used_ * sizeof(Dtype); + if (Caffe::root_solver()) { + DLOG(INFO) << "Memory required for data: " + << memory_used_ * sizeof(Dtype); + } const int param_size = layer_param.param_size(); const int num_param_blobs = layers_[layer_id]->blobs().size(); CHECK_LE(param_size, num_param_blobs) @@ -191,10 +232,14 @@ void Net::Init(const NetParameter& in_param) { } if (!layer_contributes_loss) { layer_need_backward_[layer_id] = false; } if (layer_need_backward_[layer_id]) { - LOG(INFO) << layer_names_[layer_id] << " needs backward computation."; + if (Caffe::root_solver()) { + LOG(INFO) << layer_names_[layer_id] << " needs backward computation."; + } } else { - LOG(INFO) << layer_names_[layer_id] - << " does not need backward computation."; + if (Caffe::root_solver()) { + LOG(INFO) << layer_names_[layer_id] + << " does not need backward computation."; + } } for (int bottom_id = 0; bottom_id < bottom_vecs_[layer_id].size(); ++bottom_id) { @@ -234,7 +279,9 @@ void Net::Init(const NetParameter& in_param) { // In the end, all remaining blobs are considered output blobs. for (set::iterator it = available_blobs.begin(); it != available_blobs.end(); ++it) { - LOG(INFO) << "This network produces output " << *it; + if (Caffe::root_solver()) { + LOG(INFO) << "This network produces output " << *it; + } net_output_blobs_.push_back(blobs_[blob_name_to_idx[*it]].get()); net_output_blob_indices_.push_back(blob_name_to_idx[*it]); } @@ -246,8 +293,10 @@ void Net::Init(const NetParameter& in_param) { } ShareWeights(); debug_info_ = param.debug_info(); - LOG(INFO) << "Network initialization done."; - LOG(INFO) << "Memory required for data: " << memory_used_ * sizeof(Dtype); + if (Caffe::root_solver()) { + LOG(INFO) << "Network initialization done."; + LOG(INFO) << "Memory required for data: " << memory_used_ * sizeof(Dtype); + } } template @@ -286,27 +335,33 @@ bool Net::StateMeetsRule(const NetState& state, // Check whether the rule is broken due to phase. if (rule.has_phase()) { if (rule.phase() != state.phase()) { - LOG(INFO) << "The NetState phase (" << state.phase() - << ") differed from the phase (" << rule.phase() - << ") specified by a rule in layer " << layer_name; + if (Caffe::root_solver()) { + LOG(INFO) << "The NetState phase (" << state.phase() + << ") differed from the phase (" << rule.phase() + << ") specified by a rule in layer " << layer_name; + } return false; } } // Check whether the rule is broken due to min level. if (rule.has_min_level()) { if (state.level() < rule.min_level()) { - LOG(INFO) << "The NetState level (" << state.level() - << ") is above the min_level (" << rule.min_level() - << ") specified by a rule in layer " << layer_name; + if (Caffe::root_solver()) { + LOG(INFO) << "The NetState level (" << state.level() + << ") is above the min_level (" << rule.min_level() + << ") specified by a rule in layer " << layer_name; + } return false; } } // Check whether the rule is broken due to max level. if (rule.has_max_level()) { if (state.level() > rule.max_level()) { - LOG(INFO) << "The NetState level (" << state.level() - << ") is above the max_level (" << rule.max_level() - << ") specified by a rule in layer " << layer_name; + if (Caffe::root_solver()) { + LOG(INFO) << "The NetState level (" << state.level() + << ") is above the max_level (" << rule.max_level() + << ") specified by a rule in layer " << layer_name; + } return false; } } @@ -319,8 +374,10 @@ bool Net::StateMeetsRule(const NetState& state, if (rule.stage(i) == state.stage(j)) { has_stage = true; } } if (!has_stage) { - LOG(INFO) << "The NetState did not contain stage '" << rule.stage(i) - << "' specified by a rule in layer " << layer_name; + if (Caffe::root_solver()) { + LOG(INFO) << "The NetState did not contain stage '" << rule.stage(i) + << "' specified by a rule in layer " << layer_name; + } return false; } } @@ -333,8 +390,10 @@ bool Net::StateMeetsRule(const NetState& state, if (rule.not_stage(i) == state.stage(j)) { has_stage = true; } } if (has_stage) { - LOG(INFO) << "The NetState contained a not_stage '" << rule.not_stage(i) - << "' specified by a rule in layer " << layer_name; + if (Caffe::root_solver()) { + LOG(INFO) << "The NetState contained a not_stage '" << rule.not_stage(i) + << "' specified by a rule in layer " << layer_name; + } return false; } } @@ -356,7 +415,9 @@ void Net::AppendTop(const NetParameter& param, const int layer_id, if (blob_name_to_idx && layer_param && layer_param->bottom_size() > top_id && blob_name == layer_param->bottom(top_id)) { // In-place computation - LOG(INFO) << layer_param->name() << " -> " << blob_name << " (in-place)"; + if (Caffe::root_solver()) { + LOG(INFO) << layer_param->name() << " -> " << blob_name << " (in-place)"; + } top_vecs_[layer_id].push_back(blobs_[(*blob_name_to_idx)[blob_name]].get()); top_id_vecs_[layer_id].push_back((*blob_name_to_idx)[blob_name]); } else if (blob_name_to_idx && @@ -366,10 +427,12 @@ void Net::AppendTop(const NetParameter& param, const int layer_id, LOG(FATAL) << "Duplicate blobs produced by multiple sources."; } else { // Normal output. - if (layer_param) { - LOG(INFO) << layer_param->name() << " -> " << blob_name; - } else { - LOG(INFO) << "Input " << top_id << " -> " << blob_name; + if (Caffe::root_solver()) { + if (layer_param) { + LOG(INFO) << layer_param->name() << " -> " << blob_name; + } else { + LOG(INFO) << "Input " << top_id << " -> " << blob_name; + } } shared_ptr > blob_pointer(new Blob()); const int blob_id = blobs_.size(); @@ -409,7 +472,9 @@ int Net::AppendBottom(const NetParameter& param, const int layer_id, << " (at index " << bottom_id << ") to layer " << layer_id; } const int blob_id = (*blob_name_to_idx)[blob_name]; - LOG(INFO) << layer_names_[layer_id] << " <- " << blob_name; + if (Caffe::root_solver()) { + LOG(INFO) << layer_names_[layer_id] << " <- " << blob_name; + } bottom_vecs_[layer_id].push_back(blobs_[blob_id].get()); bottom_id_vecs_[layer_id].push_back(blob_id); available_blobs->erase(blob_name); @@ -468,9 +533,10 @@ void Net::AppendParam(const NetParameter& param, const int layer_id, param_layer_indices_[owner_net_param_id]; const int owner_layer_id = owner_index.first; const int owner_param_id = owner_index.second; - LOG(INFO) << "Sharing parameters '" << param_name << "' owned by " - << "layer '" << layer_names_[owner_layer_id] << "', param " - << "index " << owner_param_id; + LOG_IF(INFO, Caffe::root_solver()) << "Sharing parameters '" << param_name + << "' owned by " + << "layer '" << layer_names_[owner_layer_id] << "', param " + << "index " << owner_param_id; Blob* this_blob = layers_[layer_id]->blobs()[param_id].get(); Blob* owner_blob = layers_[owner_layer_id]->blobs()[owner_param_id].get(); @@ -595,8 +661,10 @@ void Net::InputDebugInfo(const int input_id) { const Blob& blob = *net_input_blobs_[input_id]; const string& blob_name = blob_names_[net_input_blob_indices_[input_id]]; const Dtype data_abs_val_mean = blob.asum_data() / blob.count(); - LOG(INFO) << " [Forward] " - << "Input " << blob_name << " data: " << data_abs_val_mean; + if (Caffe::root_solver()) { + LOG(INFO) << " [Forward] " + << "Input " << blob_name << " data: " << data_abs_val_mean; + } } template @@ -605,9 +673,12 @@ void Net::ForwardDebugInfo(const int layer_id) { const Blob& blob = *top_vecs_[layer_id][top_id]; const string& blob_name = blob_names_[top_id_vecs_[layer_id][top_id]]; const Dtype data_abs_val_mean = blob.asum_data() / blob.count(); - LOG(INFO) << " [Forward] " - << "Layer " << layer_names_[layer_id] << ", top blob " << blob_name - << " data: " << data_abs_val_mean; + if (Caffe::root_solver()) { + LOG(INFO) << " [Forward] " + << "Layer " << layer_names_[layer_id] + << ", top blob " << blob_name + << " data: " << data_abs_val_mean; + } } for (int param_id = 0; param_id < layers_[layer_id]->blobs().size(); ++param_id) { @@ -615,9 +686,12 @@ void Net::ForwardDebugInfo(const int layer_id) { const int net_param_id = param_id_vecs_[layer_id][param_id]; const string& blob_name = param_display_names_[net_param_id]; const Dtype data_abs_val_mean = blob.asum_data() / blob.count(); - LOG(INFO) << " [Forward] " - << "Layer " << layer_names_[layer_id] << ", param blob " << blob_name - << " data: " << data_abs_val_mean; + if (Caffe::root_solver()) { + LOG(INFO) << " [Forward] " + << "Layer " << layer_names_[layer_id] + << ", param blob " << blob_name + << " data: " << data_abs_val_mean; + } } } @@ -629,18 +703,24 @@ void Net::BackwardDebugInfo(const int layer_id) { const Blob& blob = *bottom_vec[bottom_id]; const string& blob_name = blob_names_[bottom_id_vecs_[layer_id][bottom_id]]; const Dtype diff_abs_val_mean = blob.asum_diff() / blob.count(); - LOG(INFO) << " [Backward] " - << "Layer " << layer_names_[layer_id] << ", bottom blob " << blob_name - << " diff: " << diff_abs_val_mean; + if (Caffe::root_solver()) { + LOG(INFO) << " [Backward] " + << "Layer " << layer_names_[layer_id] + << ", bottom blob " << blob_name + << " diff: " << diff_abs_val_mean; + } } for (int param_id = 0; param_id < layers_[layer_id]->blobs().size(); ++param_id) { if (!layers_[layer_id]->param_propagate_down(param_id)) { continue; } const Blob& blob = *layers_[layer_id]->blobs()[param_id]; const Dtype diff_abs_val_mean = blob.asum_diff() / blob.count(); - LOG(INFO) << " [Backward] " - << "Layer " << layer_names_[layer_id] << ", param blob " << param_id - << " diff: " << diff_abs_val_mean; + if (Caffe::root_solver()) { + LOG(INFO) << " [Backward] " + << "Layer " << layer_names_[layer_id] + << ", param blob " << param_id + << " diff: " << diff_abs_val_mean; + } } } @@ -653,17 +733,22 @@ void Net::UpdateDebugInfo(const int param_id) { const Dtype diff_abs_val_mean = blob.asum_diff() / blob.count(); if (param_owner < 0) { const Dtype data_abs_val_mean = blob.asum_data() / blob.count(); - LOG(INFO) << " [Update] Layer " << layer_name - << ", param " << param_display_name - << " data: " << data_abs_val_mean << "; diff: " << diff_abs_val_mean; + if (Caffe::root_solver()) { + LOG(INFO) << " [Update] Layer " << layer_name + << ", param " << param_display_name + << " data: " << data_abs_val_mean + << "; diff: " << diff_abs_val_mean; + } } else { const string& owner_layer_name = layer_names_[param_layer_indices_[param_owner].first]; - LOG(INFO) << " [Update] Layer " << layer_name - << ", param blob " << param_display_name - << " (owned by layer " << owner_layer_name << ", " - << "param " << param_display_names_[param_owners_[param_id]] << ")" - << " diff: " << diff_abs_val_mean; + if (Caffe::root_solver()) { + LOG(INFO) << " [Update] Layer " << layer_name + << ", param blob " << param_display_name + << " (owned by layer " << owner_layer_name << ", " << "param " + << param_display_names_[param_owners_[param_id]] << ")" + << " diff: " << diff_abs_val_mean; + } } } @@ -720,8 +805,8 @@ void Net::Backward() { const Dtype l2norm_data = std::sqrt(sumsq_data); const Dtype l2norm_diff = std::sqrt(sumsq_diff); LOG(ERROR) << " [Backward] All net params (data, diff): " - << "L1 norm = (" << asum_data << ", " << asum_diff << "); " - << "L2 norm = (" << l2norm_data << ", " << l2norm_diff << ")"; + << "L1 norm = (" << asum_data << ", " << asum_diff << "); " + << "L2 norm = (" << l2norm_data << ", " << l2norm_diff << ")"; } } diff --git a/src/caffe/parallel.cpp b/src/caffe/parallel.cpp new file mode 100644 index 00000000000..6e7d802bb99 --- /dev/null +++ b/src/caffe/parallel.cpp @@ -0,0 +1,438 @@ +#ifndef CPU_ONLY +#include +#endif +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "boost/thread.hpp" +#include "caffe/caffe.hpp" +#include "caffe/parallel.hpp" + +namespace caffe { + +enum Op { + copy, + replace_cpu, + replace_gpu, + replace_cpu_diff, + replace_gpu_diff +}; + +template +static void apply_buffers(const vector*>& blobs, + Dtype* buffer, size_t total_size, Op op) { + Dtype* ptr = buffer; + for (int i = 0; i < blobs.size(); ++i) { + int size = blobs[i]->count(); + switch (op) { + case copy: { + // Init buffer to current values of blobs + caffe_copy(size, + reinterpret_cast(blobs[i]->data()->cpu_data()), + ptr); + break; + } + case replace_cpu: + blobs[i]->data()->set_cpu_data(ptr); + break; + case replace_gpu: + blobs[i]->data()->set_gpu_data(ptr); + break; + case replace_cpu_diff: + blobs[i]->diff()->set_cpu_data(ptr); + break; + case replace_gpu_diff: + blobs[i]->diff()->set_gpu_data(ptr); + break; + } + ptr += size; + } + CHECK_EQ(total_size, ptr - buffer); +} + +// Buffer size necessary to store given blobs +template +static size_t total_size(const vector*>& params) { + size_t size = 0; + for (int i = 0; i < params.size(); ++i) + size += params[i]->count(); + return size; +} + +template +Params::Params(shared_ptr > root_solver) + : size_(total_size(root_solver->net()->learnable_params())), + data_(), + diff_() { +} + +template +GPUParams::GPUParams(shared_ptr > root_solver, int device) + : Params(root_solver) { +#ifndef CPU_ONLY + int initial_device; + CUDA_CHECK(cudaGetDevice(&initial_device)); + + // Allocate device buffers + CUDA_CHECK(cudaSetDevice(device)); + CUDA_CHECK(cudaMalloc(&data_, size_ * sizeof(Dtype))); + + // Copy blob values + const vector*>& net = + root_solver->net()->learnable_params(); + apply_buffers(net, data_, size_, copy); + + CUDA_CHECK(cudaMalloc(&diff_, size_ * sizeof(Dtype))); + caffe_gpu_set(size_, Dtype(0), diff_); + + CUDA_CHECK(cudaSetDevice(initial_device)); +#else + NO_GPU; +#endif +} + +template +GPUParams::~GPUParams() { +#ifndef CPU_ONLY + CUDA_CHECK(cudaFree(data_)); + CUDA_CHECK(cudaFree(diff_)); +#endif +} + +template +void GPUParams::configure(Solver* solver) const { + const vector*>& net = + solver->net()->learnable_params(); + apply_buffers(net, data_, size_, replace_gpu); + apply_buffers(net, diff_, size_, replace_gpu_diff); +} + +void DevicePair::compute(const vector devices, vector* pairs) { +#ifndef CPU_ONLY + vector remaining(devices); + + // Depth for reduction tree + int remaining_depth = static_cast(ceil(log2(remaining.size()))); + + // Group GPUs by board + for (int d = 0; d < remaining_depth; ++d) { + for (int i = 0; i < remaining.size(); ++i) { + for (int j = i + 1; j < remaining.size(); ++j) { + cudaDeviceProp a, b; + CUDA_CHECK(cudaGetDeviceProperties(&a, remaining[i])); + CUDA_CHECK(cudaGetDeviceProperties(&b, remaining[j])); + if (a.isMultiGpuBoard && b.isMultiGpuBoard) { + if (a.multiGpuBoardGroupID == b.multiGpuBoardGroupID) { + pairs->push_back(DevicePair(remaining[i], remaining[j])); + DLOG(INFO) << "GPU board: " << remaining[i] << ":" << remaining[j]; + remaining.erase(remaining.begin() + j); + break; + } + } + } + } + } + ostringstream s; + for (int i = 0; i < remaining.size(); ++i) { + s << (i ? ", " : "") << remaining[i]; + } + DLOG(INFO) << "GPUs paired by boards, remaining: " << s.str(); + + // Group by P2P accessibility + remaining_depth = ceil(log2(remaining.size())); + for (int d = 0; d < remaining_depth; ++d) { + for (int i = 0; i < remaining.size(); ++i) { + for (int j = i + 1; j < remaining.size(); ++j) { + int access; + CUDA_CHECK( + cudaDeviceCanAccessPeer(&access, remaining[i], remaining[j])); + if (access) { + pairs->push_back(DevicePair(remaining[i], remaining[j])); + DLOG(INFO) << "P2P pair: " << remaining[i] << ":" << remaining[j]; + remaining.erase(remaining.begin() + j); + break; + } + } + } + } + s.str(""); + for (int i = 0; i < remaining.size(); ++i) { + s << (i ? ", " : "") << remaining[i]; + } + DLOG(INFO) << "GPUs paired by P2P access, remaining: " << s.str(); + + // Group remaining + remaining_depth = ceil(log2(remaining.size())); + for (int d = 0; d < remaining_depth; ++d) { + for (int i = 0; i < remaining.size(); ++i) { + pairs->push_back(DevicePair(remaining[i], remaining[i + 1])); + DLOG(INFO) << "Remaining pair: " << remaining[i] << ":" + << remaining[i + 1]; + remaining.erase(remaining.begin() + i + 1); + } + } + + // Should only be the parent node remaining + CHECK_EQ(remaining.size(), 1); + + pairs->insert(pairs->begin(), DevicePair(-1, remaining[0])); + + CHECK(pairs->size() == devices.size()); + for (int i = 0; i < pairs->size(); ++i) { + CHECK((*pairs)[i].parent() != (*pairs)[i].device()); + for (int j = i + 1; j < pairs->size(); ++j) { + CHECK((*pairs)[i].device() != (*pairs)[j].device()); + } + } +#else + NO_GPU; +#endif +} + +// + +template +P2PSync::P2PSync(shared_ptr > root_solver, + P2PSync* parent, const SolverParameter& param) + : GPUParams(root_solver, param.device_id()), + parent_(parent), + children_(), + queue_(), + initial_iter_(root_solver->iter()), + solver_() { +#ifndef CPU_ONLY + int initial_device; + CUDA_CHECK(cudaGetDevice(&initial_device)); + const int self = param.device_id(); + CUDA_CHECK(cudaSetDevice(self)); + + if (parent == NULL) { + solver_ = root_solver; + } else { + Caffe::set_root_solver(false); + solver_.reset(new WorkerSolver(param, root_solver.get())); + Caffe::set_root_solver(true); + } + this->configure(solver_.get()); + solver_->add_callback(this); + + if (parent) { + // Enable p2p access between devices + const int peer = parent->solver_->param().device_id(); + int access; + CUDA_CHECK(cudaDeviceCanAccessPeer(&access, self, peer)); + if (access) { + CUDA_CHECK(cudaDeviceEnablePeerAccess(peer, 0)); + } else { + LOG(INFO)<< "GPU " << self << " does not have p2p access to GPU " << peer; + } + // Allocate receiving buffer on parent + CUDA_CHECK(cudaSetDevice(peer)); + CUDA_CHECK(cudaMalloc(&parent_grads_, size_ * sizeof(Dtype))); + CUDA_CHECK(cudaSetDevice(self)); + } + + CUDA_CHECK(cudaSetDevice(initial_device)); +#else + NO_GPU; +#endif +} + +template +P2PSync::~P2PSync() { +#ifndef CPU_ONLY + int initial_device; + CUDA_CHECK(cudaGetDevice(&initial_device)); + const int self = solver_->param().device_id(); + CUDA_CHECK(cudaSetDevice(self)); + + if (parent_) { + CUDA_CHECK(cudaFree(parent_grads_)); + const int peer = parent_->solver_->param().device_id(); + int access; + CUDA_CHECK(cudaDeviceCanAccessPeer(&access, self, peer)); + if (access) { + CUDA_CHECK(cudaDeviceDisablePeerAccess(peer)); + } + } + + CUDA_CHECK(cudaSetDevice(initial_device)); +#endif +} + +template +void P2PSync::InternalThreadEntry() { + Caffe::SetDevice(solver_->param().device_id()); + CHECK(Caffe::root_solver()); + Caffe::set_root_solver(false); + // See if there is a defined seed and reset random state if so + if (solver_->param().random_seed() >= 0) { + // Fetch random seed and modulate by device ID to make sure + // everyone doesn't have the same seed. We seem to have some + // solver instability if we have everyone with the same seed + Caffe::set_random_seed( + solver_->param().random_seed() + solver_->param().device_id()); + } + solver_->Step(solver_->param().max_iter() - initial_iter_); +} + +template +void P2PSync::on_start() { +#ifndef CPU_ONLY +#ifdef DEBUG + int device; + CUDA_CHECK(cudaGetDevice(&device)); + CHECK(device == solver_->param().device_id()); +#else +// CHECK(false); +#endif + + // Wait for update from parent + if (parent_) { + P2PSync *parent = queue_.pop(); + CHECK(parent == parent_); + } + + // Update children + for (int i = children_.size() - 1; i >= 0; i--) { + Dtype* src = data_; + Dtype* dst = children_[i]->data_; + +#ifdef DEBUG + cudaPointerAttributes attributes; + CUDA_CHECK(cudaPointerGetAttributes(&attributes, src)); + CHECK(attributes.device == device); + CUDA_CHECK(cudaPointerGetAttributes(&attributes, dst)); + CHECK(attributes.device == children_[i]->solver_->param().device_id()); +#endif + + CUDA_CHECK(cudaMemcpyAsync(dst, src, size_ * sizeof(Dtype), + cudaMemcpyDeviceToDevice, cudaStreamDefault)); + CUDA_CHECK(cudaStreamSynchronize(cudaStreamDefault)); + children_[i]->queue_.push(this); + } +#endif +} + +template +void P2PSync::on_gradients_ready() { +#ifndef CPU_ONLY +#ifdef DEBUG + int device; + CUDA_CHECK(cudaGetDevice(&device)); + CHECK(device == solver_->param().device_id()); +#endif + + // Sum children gradients as they appear in the queue + for (int i = 0; i < children_.size(); ++i) { + P2PSync *child = queue_.pop(); + Dtype* src = child->parent_grads_; + Dtype* dst = diff_; + +#ifdef DEBUG + bool ok = false; + for (int j = 0; j < children_.size(); ++j) { + if (child == children_[j]) { + ok = true; + } + } + CHECK(ok); + cudaPointerAttributes attributes; + CUDA_CHECK(cudaPointerGetAttributes(&attributes, src)); + CHECK(attributes.device == device); + CUDA_CHECK(cudaPointerGetAttributes(&attributes, dst)); + CHECK(attributes.device == device); +#endif + + caffe_gpu_add(size_, src, dst, dst); + } + + // Send gradients to parent + if (parent_) { + Dtype* src = diff_; + Dtype* dst = parent_grads_; + +#ifdef DEBUG + cudaPointerAttributes attributes; + CUDA_CHECK(cudaPointerGetAttributes(&attributes, src)); + CHECK(attributes.device == device); + CUDA_CHECK(cudaPointerGetAttributes(&attributes, dst)); + CHECK(attributes.device == parent_->solver_->param().device_id()); +#endif + + CUDA_CHECK(cudaMemcpyAsync(dst, src, size_ * sizeof(Dtype), // + cudaMemcpyDeviceToDevice, cudaStreamDefault)); + CUDA_CHECK(cudaStreamSynchronize(cudaStreamDefault)); + parent_->queue_.push(this); + } else { + // Loss functions divide gradients by the batch size, so to compensate + // for split batch, the root solver divides by number of solvers. + caffe_gpu_scal(size_, Dtype(1.0 / Caffe::solver_count()), diff_); + } +#endif +} + +template +void P2PSync::run(const vector& gpus) { + // Pair devices for map-reduce synchronization + vector pairs; + DevicePair::compute(gpus, &pairs); + ostringstream s; + for (int i = 1; i < pairs.size(); ++i) { + s << (i == 1 ? "" : ", ") << pairs[i].parent() << ":" << pairs[i].device(); + } + LOG(INFO)<< "GPUs pairs " << s.str(); + + SolverParameter param(solver_->param()); + vector > > syncs(gpus.size()); + + // Build the GPU tree by finding the parent for each solver + for (int attempts = 0; attempts < pairs.size(); ++attempts) { + for (int i = 1; i < pairs.size(); ++i) { + if (!syncs[i].get()) { + P2PSync* parent = NULL; + for (int j = 0; j < syncs.size(); ++j) { + P2PSync* sync = j == 0 ? this : syncs[j].get(); + if (sync) { + const SolverParameter& p = sync->solver()->param(); + if (p.device_id() == pairs[i].parent()) { + parent = sync; + } + } + } + if (parent) { + param.set_device_id(pairs[i].device()); + syncs[i].reset(new P2PSync(solver_, parent, param)); + parent->children_.push_back((P2PSync*) syncs[i].get()); + } + } + } + } + + LOG(INFO)<< "Starting Optimization"; + + for (int i = 1; i < syncs.size(); ++i) { + syncs[i]->StartInternalThread(); + } + + // Run root solver on current thread + solver_->Solve(); + + for (int i = 1; i < syncs.size(); ++i) { + syncs[i]->StopInternalThread(); + } +} + +INSTANTIATE_CLASS(Params); +INSTANTIATE_CLASS(GPUParams); +INSTANTIATE_CLASS(P2PSync); + +} // namespace caffe diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index 89f14595ba6..e78c6686049 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -500,6 +500,7 @@ message DataParameter { // to avoid all asynchronous sgd clients to start at the same point. The skip // point would be set as rand_skip * rand(0,1). Note that rand_skip should not // be larger than the number of keys in the database. + // DEPRECATED. Each solver accesses a different subset of the database. optional uint32 rand_skip = 7 [default = 0]; optional DB backend = 8 [default = LEVELDB]; // DEPRECATED. See TransformationParameter. For data pre-processing, we can do @@ -515,6 +516,9 @@ message DataParameter { optional bool mirror = 6 [default = false]; // Force the encoded image to have 3 color channels optional bool force_encoded_color = 9 [default = false]; + // Prefetch queue (Number of batches to prefetch to host memory, increase if + // data access bandwidth varies). + optional uint32 prefetch = 10 [default = 4]; } message DropoutParameter { @@ -736,6 +740,10 @@ message PythonParameter { // string, dictionary in Python dict format, JSON, etc. You may parse this // string in `setup` method and use it in `forward` and `backward`. optional string param_str = 3 [default = '']; + // Whether this PythonLayer is shared among worker solvers during data parallelism. + // If true, each worker solver sequentially run forward from this layer. + // This value should be set true if you are using it as a data layer. + optional bool share_in_parallel = 4 [default = false]; } // Message that stores parameters used by ReductionLayer diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp index 54e085a63e5..a44ff88dfd6 100644 --- a/src/caffe/solver.cpp +++ b/src/caffe/solver.cpp @@ -18,14 +18,14 @@ namespace caffe { template -Solver::Solver(const SolverParameter& param) - : net_() { +Solver::Solver(const SolverParameter& param, const Solver* root_solver) + : net_(), callbacks_(), root_solver_(root_solver) { Init(param); } template -Solver::Solver(const string& param_file) - : net_() { +Solver::Solver(const string& param_file, const Solver* root_solver) + : net_(), callbacks_(), root_solver_(root_solver) { SolverParameter param; ReadProtoFromTextFileOrDie(param_file, ¶m); Init(param); @@ -33,17 +33,21 @@ Solver::Solver(const string& param_file) template void Solver::Init(const SolverParameter& param) { - LOG(INFO) << "Initializing solver from parameters: " << std::endl - << param.DebugString(); + CHECK(Caffe::root_solver() || root_solver_) + << "root_solver_ needs to be set for all non-root solvers"; + LOG_IF(INFO, Caffe::root_solver()) << "Initializing solver from parameters: " + << std::endl << param.DebugString(); param_ = param; CHECK_GE(param_.average_loss(), 1) << "average_loss should be non-negative."; - if (param_.random_seed() >= 0) { + if (Caffe::root_solver() && param_.random_seed() >= 0) { Caffe::set_random_seed(param_.random_seed()); } // Scaffolding code InitTrainNet(); - InitTestNets(); - LOG(INFO) << "Solver scaffolding done."; + if (Caffe::root_solver()) { + InitTestNets(); + LOG(INFO) << "Solver scaffolding done."; + } iter_ = 0; current_step_ = 0; } @@ -59,19 +63,22 @@ void Solver::InitTrainNet() { << "one of these fields specifying a train_net: " << field_names; NetParameter net_param; if (param_.has_train_net_param()) { - LOG(INFO) << "Creating training net specified in train_net_param."; + LOG_IF(INFO, Caffe::root_solver()) + << "Creating training net specified in train_net_param."; net_param.CopyFrom(param_.train_net_param()); } else if (param_.has_train_net()) { - LOG(INFO) << "Creating training net from train_net file: " - << param_.train_net(); + LOG_IF(INFO, Caffe::root_solver()) + << "Creating training net from train_net file: " << param_.train_net(); ReadNetParamsFromTextFileOrDie(param_.train_net(), &net_param); } if (param_.has_net_param()) { - LOG(INFO) << "Creating training net specified in net_param."; + LOG_IF(INFO, Caffe::root_solver()) + << "Creating training net specified in net_param."; net_param.CopyFrom(param_.net_param()); } if (param_.has_net()) { - LOG(INFO) << "Creating training net from net file: " << param_.net(); + LOG_IF(INFO, Caffe::root_solver()) + << "Creating training net from net file: " << param_.net(); ReadNetParamsFromTextFileOrDie(param_.net(), &net_param); } // Set the correct NetState. We start with the solver defaults (lowest @@ -83,11 +90,16 @@ void Solver::InitTrainNet() { net_state.MergeFrom(net_param.state()); net_state.MergeFrom(param_.train_state()); net_param.mutable_state()->CopyFrom(net_state); - net_.reset(new Net(net_param)); + if (Caffe::root_solver()) { + net_.reset(new Net(net_param)); + } else { + net_.reset(new Net(net_param, root_solver_->net_.get())); + } } template void Solver::InitTestNets() { + CHECK(Caffe::root_solver()); const bool has_net_param = param_.has_net_param(); const bool has_net_file = param_.has_net(); const int num_generic_nets = has_net_param + has_net_file; @@ -157,7 +169,12 @@ void Solver::InitTestNets() { net_params[i].mutable_state()->CopyFrom(net_state); LOG(INFO) << "Creating test net (#" << i << ") specified by " << sources[i]; - test_nets_[i].reset(new Net(net_params[i])); + if (Caffe::root_solver()) { + test_nets_[i].reset(new Net(net_params[i])); + } else { + test_nets_[i].reset(new Net(net_params[i], + root_solver_->test_nets_[i].get())); + } test_nets_[i]->set_debug_info(param_.debug_info()); } } @@ -175,10 +192,14 @@ void Solver::Step(int iters) { // zero-init the params net_->ClearParamDiffs(); if (param_.test_interval() && iter_ % param_.test_interval() == 0 - && (iter_ > 0 || param_.test_initialization())) { + && (iter_ > 0 || param_.test_initialization()) + && Caffe::root_solver()) { TestAll(); } + for (int i = 0; i < callbacks_.size(); ++i) { + callbacks_[i]->on_start(); + } const bool display = param_.display() && iter_ % param_.display() == 0; net_->set_debug_info(display && param_.debug_info()); // accumulate the loss and gradient @@ -198,7 +219,8 @@ void Solver::Step(int iters) { losses[idx] = loss; } if (display) { - LOG(INFO) << "Iteration " << iter_ << ", loss = " << smoothed_loss; + LOG_IF(INFO, Caffe::root_solver()) << "Iteration " << iter_ + << ", loss = " << smoothed_loss; const vector*>& result = net_->output_blobs(); int score_index = 0; for (int j = 0; j < result.size(); ++j) { @@ -213,12 +235,15 @@ void Solver::Step(int iters) { loss_msg_stream << " (* " << loss_weight << " = " << loss_weight * result_vec[k] << " loss)"; } - LOG(INFO) << " Train net output #" + LOG_IF(INFO, Caffe::root_solver()) << " Train net output #" << score_index++ << ": " << output_name << " = " << result_vec[k] << loss_msg_stream.str(); } } } + for (int i = 0; i < callbacks_.size(); ++i) { + callbacks_[i]->on_gradients_ready(); + } ApplyUpdate(); // Increment the internal iter_ counter -- its value should always indicate @@ -226,7 +251,9 @@ void Solver::Step(int iters) { ++iter_; // Save a snapshot if needed. - if (param_.snapshot() && iter_ % param_.snapshot() == 0) { + if (param_.snapshot() + && iter_ % param_.snapshot() == 0 + && Caffe::root_solver()) { Snapshot(); } } @@ -234,6 +261,7 @@ void Solver::Step(int iters) { template void Solver::Solve(const char* resume_file) { + CHECK(Caffe::root_solver()); LOG(INFO) << "Solving " << net_->name(); LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy(); @@ -278,6 +306,7 @@ void Solver::TestAll() { template void Solver::Test(const int test_net_id) { + CHECK(Caffe::root_solver()); LOG(INFO) << "Iteration " << iter_ << ", Testing net (#" << test_net_id << ")"; CHECK_NOTNULL(test_nets_[test_net_id].get())-> @@ -328,13 +357,14 @@ void Solver::Test(const int test_net_id) { << " = " << loss_weight * mean_score << " loss)"; } LOG(INFO) << " Test net output #" << i << ": " << output_name << " = " - << mean_score << loss_msg_stream.str(); + << mean_score << loss_msg_stream.str(); } } template void Solver::Snapshot() { + CHECK(Caffe::root_solver()); string model_filename; switch (param_.snapshot_format()) { case caffe::SolverParameter_SnapshotFormat_BINARYPROTO: @@ -379,6 +409,7 @@ string Solver::SnapshotToHDF5() { template void Solver::Restore(const char* state_file) { + CHECK(Caffe::root_solver()); string state_filename(state_file); if (state_filename.size() >= 3 && state_filename.compare(state_filename.size() - 3, 3, ".h5") == 0) { @@ -480,6 +511,7 @@ void SGDSolver::ClipGradients() { template void SGDSolver::ApplyUpdate() { + CHECK(Caffe::root_solver()); Dtype rate = GetLearningRate(); if (this->param_.display() && this->iter_ % this->param_.display() == 0) { LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate; @@ -723,6 +755,7 @@ void SGDSolver::RestoreSolverStateFromHDF5(const string& state_file) { template void NesterovSolver::ComputeUpdateValue(int param_id, Dtype rate) { + CHECK(Caffe::root_solver()); const vector*>& net_params = this->net_->learnable_params(); const vector& net_params_lr = this->net_->params_lr(); Dtype momentum = this->param_.momentum(); @@ -783,6 +816,7 @@ void NesterovSolver::ComputeUpdateValue(int param_id, Dtype rate) { template void AdaGradSolver::ComputeUpdateValue(int param_id, Dtype rate) { + CHECK(Caffe::root_solver()); const vector*>& net_params = this->net_->learnable_params(); const vector& net_params_lr = this->net_->params_lr(); Dtype delta = this->param_.delta(); diff --git a/src/caffe/syncedmem.cpp b/src/caffe/syncedmem.cpp index 7617ccfb27f..a667a867af0 100644 --- a/src/caffe/syncedmem.cpp +++ b/src/caffe/syncedmem.cpp @@ -12,8 +12,14 @@ SyncedMemory::~SyncedMemory() { } #ifndef CPU_ONLY - if (gpu_ptr_) { + if (gpu_ptr_ && own_gpu_data_) { + int initial_device; + cudaGetDevice(&initial_device); + if (gpu_device_ != -1) { + CUDA_CHECK(cudaSetDevice(gpu_device_)); + } CUDA_CHECK(cudaFree(gpu_ptr_)); + cudaSetDevice(initial_device); } #endif // CPU_ONLY } @@ -48,13 +54,17 @@ inline void SyncedMemory::to_gpu() { #ifndef CPU_ONLY switch (head_) { case UNINITIALIZED: + CUDA_CHECK(cudaGetDevice(&gpu_device_)); CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_)); caffe_gpu_memset(size_, 0, gpu_ptr_); head_ = HEAD_AT_GPU; + own_gpu_data_ = true; break; case HEAD_AT_CPU: if (gpu_ptr_ == NULL) { + CUDA_CHECK(cudaGetDevice(&gpu_device_)); CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_)); + own_gpu_data_ = true; } caffe_gpu_memcpy(size_, cpu_ptr_, gpu_ptr_); head_ = SYNCED; @@ -92,6 +102,26 @@ const void* SyncedMemory::gpu_data() { #endif } +void SyncedMemory::set_gpu_data(void* data) { +#ifndef CPU_ONLY + CHECK(data); + if (own_gpu_data_) { + int initial_device; + cudaGetDevice(&initial_device); + if (gpu_device_ != -1) { + CUDA_CHECK(cudaSetDevice(gpu_device_)); + } + CUDA_CHECK(cudaFree(gpu_ptr_)); + cudaSetDevice(initial_device); + } + gpu_ptr_ = data; + head_ = HEAD_AT_GPU; + own_gpu_data_ = false; +#else + NO_GPU; +#endif +} + void* SyncedMemory::mutable_cpu_data() { to_cpu(); head_ = HEAD_AT_CPU; @@ -108,6 +138,20 @@ void* SyncedMemory::mutable_gpu_data() { #endif } +#ifndef CPU_ONLY +void SyncedMemory::async_gpu_push(const cudaStream_t& stream) { + CHECK(head_ == HEAD_AT_CPU); + if (gpu_ptr_ == NULL) { + CUDA_CHECK(cudaGetDevice(&gpu_device_)); + CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_)); + own_gpu_data_ = true; + } + const cudaMemcpyKind put = cudaMemcpyHostToDevice; + CUDA_CHECK(cudaMemcpyAsync(gpu_ptr_, cpu_ptr_, size_, put, stream)); + // Assume caller will synchronize on the stream before use + head_ = SYNCED; +} +#endif } // namespace caffe diff --git a/src/caffe/test/test_gradient_based_solver.cpp b/src/caffe/test/test_gradient_based_solver.cpp index eaa7a759b9b..1cede07f774 100644 --- a/src/caffe/test/test_gradient_based_solver.cpp +++ b/src/caffe/test/test_gradient_based_solver.cpp @@ -8,6 +8,7 @@ #include "gtest/gtest.h" #include "caffe/common.hpp" +#include "caffe/parallel.hpp" #include "caffe/proto/caffe.pb.h" #include "caffe/solver.hpp" #include "caffe/util/io.hpp" @@ -35,6 +36,7 @@ class GradientBasedSolverTest : public MultiDeviceTest { string snapshot_prefix_; shared_ptr > solver_; + shared_ptr > sync_; int seed_; // Dimensions are determined by generate_sample_data.py // TODO this is brittle and the hdf5 file should be checked instead. @@ -70,8 +72,8 @@ class GradientBasedSolverTest : public MultiDeviceTest { string RunLeastSquaresSolver(const Dtype learning_rate, const Dtype weight_decay, const Dtype momentum, const int num_iters, - const int iter_size = 1, const bool snapshot = false, - const char* from_snapshot = NULL) { + const int iter_size = 1, const int devices = 1, + const bool snapshot = false, const char* from_snapshot = NULL) { ostringstream proto; proto << "snapshot_after_train: " << snapshot << " " @@ -184,7 +186,20 @@ class GradientBasedSolverTest : public MultiDeviceTest { this->solver_->net()->Forward(empty_bottom_vec); } } - this->solver_->Solve(); + if (devices == 1) { + this->solver_->Solve(); + } else { + LOG(INFO) << "Multi-GPU test on " << devices << " devices"; + vector gpus; + for (int i = 0; i < devices; ++i) { + gpus.push_back(i); + } + Caffe::set_solver_count(gpus.size()); + this->sync_.reset(new P2PSync( + this->solver_, NULL, this->solver_->param())); + this->sync_->run(gpus); + Caffe::set_solver_count(1); + } if (snapshot) { ostringstream resume_file; resume_file << snapshot_prefix_ << "/_iter_" << num_iters @@ -410,20 +425,38 @@ class GradientBasedSolverTest : public MultiDeviceTest { void TestLeastSquaresUpdate(const Dtype learning_rate = 1.0, const Dtype weight_decay = 0.0, const Dtype momentum = 0.0, const int iter_to_check = 0) { - // Initialize the solver and run K (= iter_to_check) solver iterations. - RunLeastSquaresSolver(learning_rate, weight_decay, momentum, iter_to_check); - - // Compute the (K+1)th update using the analytic least squares gradient. - vector > > updated_params; - ComputeLeastSquaresUpdate(learning_rate, weight_decay, momentum, - &updated_params); - - // Reinitialize the solver and run K+1 solver iterations. - RunLeastSquaresSolver(learning_rate, weight_decay, momentum, - iter_to_check + 1); - - // Check that the solver's solution matches ours. - CheckLeastSquaresUpdate(updated_params); + const int kNum = num_; + const int kIterSize = 1; + // Test over all numbers of devices. + int available_devices = 1; +#ifndef CPU_ONLY + if (Caffe::mode() == Caffe::GPU) { + CUDA_CHECK(cudaGetDeviceCount(&available_devices)); + } +#endif + for (int devices = 1; devices <= available_devices; ++devices) { + // Configure batch size for single / multi device equivalence. + // Constant data is needed for multi device as for accumulation. + num_ = kNum * devices; + + // Initialize the solver and run K (= iter_to_check) solver iterations + // (on single device). + RunLeastSquaresSolver(learning_rate, weight_decay, momentum, + iter_to_check, kIterSize, 1); + + // Compute the (K+1)th update using the analytic least squares gradient. + vector > > updated_params; + ComputeLeastSquaresUpdate(learning_rate, weight_decay, momentum, + &updated_params); + + // Reinitialize the solver and run K+1 solver iterations. + num_ = kNum; + RunLeastSquaresSolver(learning_rate, weight_decay, momentum, + iter_to_check + 1, kIterSize, devices); + + // Check that the solver's solution matches ours. + CheckLeastSquaresUpdate(updated_params); + } } void TestSnapshot(const Dtype learning_rate = 1.0, @@ -433,8 +466,9 @@ class GradientBasedSolverTest : public MultiDeviceTest { const int total_num_iters = num_iters * 2; bool snapshot = false; const int kIterSize = 1; + const int kDevices = 1; RunLeastSquaresSolver(learning_rate, weight_decay, momentum, - total_num_iters, kIterSize, snapshot); + total_num_iters, kIterSize, kDevices, snapshot); // Save the resulting param values. vector > > param_copies; @@ -464,12 +498,13 @@ class GradientBasedSolverTest : public MultiDeviceTest { // Run the solver for num_iters iterations and snapshot. snapshot = true; string snapshot_name = RunLeastSquaresSolver(learning_rate, weight_decay, - momentum, num_iters, kIterSize, snapshot); + momentum, num_iters, kIterSize, kDevices, snapshot); // Reinitialize the solver and run for num_iters more iterations. snapshot = false; RunLeastSquaresSolver(learning_rate, weight_decay, momentum, - total_num_iters, kIterSize, snapshot, snapshot_name.c_str()); + total_num_iters, kIterSize, kDevices, + snapshot, snapshot_name.c_str()); // Check that params now match. const vector*>& params = solver_->net()->learnable_params(); diff --git a/src/caffe/test/test_internal_thread.cpp b/src/caffe/test/test_internal_thread.cpp index 31882b6db1d..93f1cc541cd 100644 --- a/src/caffe/test/test_internal_thread.cpp +++ b/src/caffe/test/test_internal_thread.cpp @@ -2,6 +2,7 @@ #include "gtest/gtest.h" #include "caffe/internal_thread.hpp" +#include "caffe/util/math_functions.hpp" #include "caffe/test/test_caffe_main.hpp" @@ -13,11 +14,40 @@ class InternalThreadTest : public ::testing::Test {}; TEST_F(InternalThreadTest, TestStartAndExit) { InternalThread thread; EXPECT_FALSE(thread.is_started()); - EXPECT_TRUE(thread.StartInternalThread()); + thread.StartInternalThread(); EXPECT_TRUE(thread.is_started()); - EXPECT_TRUE(thread.WaitForInternalThreadToExit()); + thread.StopInternalThread(); EXPECT_FALSE(thread.is_started()); } +class TestThreadA : public InternalThread { + void InternalThreadEntry() { + EXPECT_EQ(4244559767, caffe_rng_rand()); + } +}; + +class TestThreadB : public InternalThread { + void InternalThreadEntry() { + EXPECT_EQ(1726478280, caffe_rng_rand()); + } +}; + +TEST_F(InternalThreadTest, TestRandomSeed) { + TestThreadA t1; + Caffe::set_random_seed(9658361); + t1.StartInternalThread(); + t1.StopInternalThread(); + + TestThreadA t2; + Caffe::set_random_seed(9658361); + t2.StartInternalThread(); + t2.StopInternalThread(); + + TestThreadB t3; + Caffe::set_random_seed(3435563); + t3.StartInternalThread(); + t3.StopInternalThread(); +} + } // namespace caffe diff --git a/src/caffe/test/test_layer_factory.cpp b/src/caffe/test/test_layer_factory.cpp index efb1b37ac42..c86fafd000c 100644 --- a/src/caffe/test/test_layer_factory.cpp +++ b/src/caffe/test/test_layer_factory.cpp @@ -1,11 +1,14 @@ #include #include +#include "boost/scoped_ptr.hpp" #include "gtest/gtest.h" #include "caffe/common.hpp" #include "caffe/layer.hpp" #include "caffe/layer_factory.hpp" +#include "caffe/util/db.hpp" +#include "caffe/util/io.hpp" #include "caffe/test/test_caffe_main.hpp" @@ -21,11 +24,20 @@ TYPED_TEST(LayerFactoryTest, TestCreateLayer) { typename LayerRegistry::CreatorRegistry& registry = LayerRegistry::Registry(); shared_ptr > layer; - LayerParameter layer_param; for (typename LayerRegistry::CreatorRegistry::iterator iter = registry.begin(); iter != registry.end(); ++iter) { // Special case: PythonLayer is checked by pytest if (iter->first == "Python") { continue; } + LayerParameter layer_param; + // Data layers expect a DB + if (iter->first == "Data") { + string tmp; + MakeTempDir(&tmp); + boost::scoped_ptr db(db::GetDB(DataParameter_DB_LEVELDB)); + db->Open(tmp, db::NEW); + db->Close(); + layer_param.mutable_data_param()->set_source(tmp); + } layer_param.set_type(iter->first); layer = LayerRegistry::CreateLayer(layer_param); EXPECT_EQ(iter->first, layer->type()); diff --git a/src/caffe/test/test_upgrade_proto.cpp b/src/caffe/test/test_upgrade_proto.cpp index eec627656ef..006720231a5 100644 --- a/src/caffe/test/test_upgrade_proto.cpp +++ b/src/caffe/test/test_upgrade_proto.cpp @@ -2,12 +2,15 @@ #include #include +#include "boost/scoped_ptr.hpp" #include "google/protobuf/text_format.h" #include "gtest/gtest.h" #include "caffe/blob.hpp" #include "caffe/common.hpp" #include "caffe/layer.hpp" +#include "caffe/util/db.hpp" +#include "caffe/util/io.hpp" #include "caffe/util/upgrade_proto.hpp" #include "caffe/test/test_caffe_main.hpp" @@ -2901,6 +2904,15 @@ TEST_F(NetUpgradeTest, TestUpgradeV1LayerType) { continue; // Empty string isn't actually a valid layer type. } layer_param.set_type(v2_layer_type); + // Data layers expect a DB + if (v2_layer_type == "Data") { + string tmp; + MakeTempDir(&tmp); + boost::scoped_ptr db(db::GetDB(DataParameter_DB_LEVELDB)); + db->Open(tmp, db::NEW); + db->Close(); + layer_param.mutable_data_param()->set_source(tmp); + } layer = LayerRegistry::CreateLayer(layer_param); EXPECT_EQ(v2_layer_type, layer->type()); } diff --git a/src/caffe/util/blocking_queue.cpp b/src/caffe/util/blocking_queue.cpp new file mode 100644 index 00000000000..d1d1fa864c3 --- /dev/null +++ b/src/caffe/util/blocking_queue.cpp @@ -0,0 +1,96 @@ +#include +#include + +#include "caffe/data_layers.hpp" +#include "caffe/data_reader.hpp" +#include "caffe/parallel.hpp" +#include "caffe/util/blocking_queue.hpp" + +namespace caffe { + +template +class BlockingQueue::sync { + public: + mutable boost::mutex mutex_; + boost::condition_variable condition_; +}; + +template +BlockingQueue::BlockingQueue() + : sync_(new sync()) { +} + +template +void BlockingQueue::push(const T& t) { + boost::mutex::scoped_lock lock(sync_->mutex_); + queue_.push(t); + lock.unlock(); + sync_->condition_.notify_one(); +} + +template +bool BlockingQueue::try_pop(T* t) { + boost::mutex::scoped_lock lock(sync_->mutex_); + + if (queue_.empty()) { + return false; + } + + *t = queue_.front(); + queue_.pop(); + return true; +} + +template +T BlockingQueue::pop(const string& log_on_wait) { + boost::mutex::scoped_lock lock(sync_->mutex_); + + while (queue_.empty()) { + if (!log_on_wait.empty()) { + LOG_EVERY_N(INFO, 1000)<< log_on_wait; + } + sync_->condition_.wait(lock); + } + + T t = queue_.front(); + queue_.pop(); + return t; +} + +template +bool BlockingQueue::try_peek(T* t) { + boost::mutex::scoped_lock lock(sync_->mutex_); + + if (queue_.empty()) { + return false; + } + + *t = queue_.front(); + return true; +} + +template +T BlockingQueue::peek() { + boost::mutex::scoped_lock lock(sync_->mutex_); + + while (queue_.empty()) { + sync_->condition_.wait(lock); + } + + return queue_.front(); +} + +template +size_t BlockingQueue::size() const { + boost::mutex::scoped_lock lock(sync_->mutex_); + return queue_.size(); +} + +template class BlockingQueue*>; +template class BlockingQueue*>; +template class BlockingQueue; +template class BlockingQueue >; +template class BlockingQueue*>; +template class BlockingQueue*>; + +} // namespace caffe diff --git a/tools/caffe.cpp b/tools/caffe.cpp index 46f99594800..9f31b37ac2b 100644 --- a/tools/caffe.cpp +++ b/tools/caffe.cpp @@ -17,13 +17,17 @@ using caffe::Blob; using caffe::Caffe; using caffe::Net; using caffe::Layer; +using caffe::Solver; using caffe::shared_ptr; +using caffe::string; using caffe::Timer; using caffe::vector; +using std::ostringstream; - -DEFINE_int32(gpu, -1, - "Run in GPU mode on given device ID."); +DEFINE_string(gpu, "", + "Optional; run in GPU mode on given device IDs separated by ','." + "Use '-gpu all' to run on all available GPUs. The effective training " + "batch size is multiplied by the number of devices."); DEFINE_string(solver, "", "The solver definition protocol buffer text file."); DEFINE_string(model, "", @@ -31,8 +35,8 @@ DEFINE_string(model, "", DEFINE_string(snapshot, "", "Optional; the snapshot solver state to resume training."); DEFINE_string(weights, "", - "Optional; the pretrained weights to initialize finetuning. " - "Cannot be set simultaneously with snapshot."); + "Optional; the pretrained weights to initialize finetuning, " + "separated by ','. Cannot be set simultaneously with snapshot."); DEFINE_int32(iterations, 50, "The number of iterations to run."); @@ -66,6 +70,29 @@ static BrewFunction GetBrewFunction(const caffe::string& name) { } } +// Parse GPU ids or use all available devices +static void get_gpus(vector* gpus) { + if (FLAGS_gpu == "all") { + int count = 0; +#ifndef CPU_ONLY + CUDA_CHECK(cudaGetDeviceCount(&count)); +#else + NO_GPU; +#endif + for (int i = 0; i < count; ++i) { + gpus->push_back(i); + } + } else if (FLAGS_gpu.size()) { + vector strings; + boost::split(strings, FLAGS_gpu, boost::is_any_of(",")); + for (int i = 0; i < strings.size(); ++i) { + gpus->push_back(boost::lexical_cast(strings[i])); + } + } else { + CHECK_EQ(gpus->size(), 0); + } +} + // caffe commands to call by // caffe // @@ -74,10 +101,13 @@ static BrewFunction GetBrewFunction(const caffe::string& name) { // Device Query: show diagnostic information for a GPU device. int device_query() { - CHECK_GT(FLAGS_gpu, -1) << "Need a device ID to query."; - LOG(INFO) << "Querying device ID = " << FLAGS_gpu; - caffe::Caffe::SetDevice(FLAGS_gpu); - caffe::Caffe::DeviceQuery(); + LOG(INFO) << "Querying GPUs " << FLAGS_gpu; + vector gpus; + get_gpus(&gpus); + for (int i = 0; i < gpus.size(); ++i) { + caffe::Caffe::SetDevice(gpus[i]); + caffe::Caffe::DeviceQuery(); + } return 0; } RegisterBrewFunction(device_query); @@ -106,34 +136,49 @@ int train() { caffe::SolverParameter solver_param; caffe::ReadProtoFromTextFileOrDie(FLAGS_solver, &solver_param); - // If the gpu flag is not provided, allow the mode and device to be set + // If the gpus flag is not provided, allow the mode and device to be set // in the solver prototxt. - if (FLAGS_gpu < 0 + if (FLAGS_gpu.size() == 0 && solver_param.solver_mode() == caffe::SolverParameter_SolverMode_GPU) { - FLAGS_gpu = solver_param.device_id(); + if (solver_param.has_device_id()) { + FLAGS_gpu = "" + + boost::lexical_cast(solver_param.device_id()); + } else { // Set default GPU if unspecified + FLAGS_gpu = "" + boost::lexical_cast(0); + } } - // Set device id and mode - if (FLAGS_gpu >= 0) { - LOG(INFO) << "Use GPU with device ID " << FLAGS_gpu; - Caffe::SetDevice(FLAGS_gpu); - Caffe::set_mode(Caffe::GPU); - } else { - LOG(INFO) << "Use CPU."; + vector gpus; + get_gpus(&gpus); + if (gpus.size() == 0) { Caffe::set_mode(Caffe::CPU); + } else { + ostringstream s; + for (int i = 0; i < gpus.size(); ++i) { + s << (i ? ", " : "") << gpus[i]; + } + LOG(INFO) << "Using GPUs " << s.str(); + + solver_param.set_device_id(gpus[0]); + Caffe::SetDevice(gpus[0]); + Caffe::set_mode(Caffe::GPU); + Caffe::set_solver_count(gpus.size()); } - LOG(INFO) << "Starting Optimization"; - shared_ptr > - solver(caffe::GetSolver(solver_param)); + shared_ptr > solver(caffe::GetSolver(solver_param)); if (FLAGS_snapshot.size()) { LOG(INFO) << "Resuming from " << FLAGS_snapshot; - solver->Solve(FLAGS_snapshot); + solver->Restore(FLAGS_snapshot.c_str()); } else if (FLAGS_weights.size()) { - CopyLayers(&*solver, FLAGS_weights); - solver->Solve(); + CopyLayers(solver.get(), FLAGS_weights); + } + + if (gpus.size() > 1) { + caffe::P2PSync sync(solver, NULL, solver->param()); + sync.run(gpus); } else { + LOG(INFO) << "Starting Optimization"; solver->Solve(); } LOG(INFO) << "Optimization Done."; @@ -148,9 +193,11 @@ int test() { CHECK_GT(FLAGS_weights.size(), 0) << "Need model weights to score."; // Set device id and mode - if (FLAGS_gpu >= 0) { - LOG(INFO) << "Use GPU with device ID " << FLAGS_gpu; - Caffe::SetDevice(FLAGS_gpu); + vector gpus; + get_gpus(&gpus); + if (gpus.size() != 0) { + LOG(INFO) << "Use GPU with device ID " << gpus[0]; + Caffe::SetDevice(gpus[0]); Caffe::set_mode(Caffe::GPU); } else { LOG(INFO) << "Use CPU."; @@ -213,9 +260,11 @@ int time() { CHECK_GT(FLAGS_model.size(), 0) << "Need a model definition to time."; // Set device id and mode - if (FLAGS_gpu >= 0) { - LOG(INFO) << "Use GPU with device ID " << FLAGS_gpu; - Caffe::SetDevice(FLAGS_gpu); + vector gpus; + get_gpus(&gpus); + if (gpus.size() != 0) { + LOG(INFO) << "Use GPU with device ID " << gpus[0]; + Caffe::SetDevice(gpus[0]); Caffe::set_mode(Caffe::GPU); } else { LOG(INFO) << "Use CPU.";