Skip to content
This repository has been archived by the owner on Jan 10, 2023. It is now read-only.

Commit

Permalink
Global gradient function registration (#421)
Browse files Browse the repository at this point in the history
  • Loading branch information
ringgaard authored Nov 2, 2019
1 parent 1b6468a commit 5cb3d0c
Show file tree
Hide file tree
Showing 40 changed files with 1,595 additions and 182 deletions.
11 changes: 10 additions & 1 deletion sling/file/file.cc
Original file line number Diff line number Diff line change
Expand Up @@ -331,15 +331,24 @@ size_t File::PageSize() {
return sysconf(_SC_PAGESIZE);
}

void *File::MapMemory(uint64 pos, size_t size) {
void *File::MapMemory(uint64 pos, size_t size, bool writable) {
return nullptr;
}

Status File::FlushMappedMemory(void *data, size_t size) {
if (default_file_system == nullptr) return NoFileSystem("mmunmap");
return default_file_system->FlushMappedMemory(data, size);
}

Status File::FreeMappedMemory(void *data, size_t size) {
if (default_file_system == nullptr) return NoFileSystem("mmunmap");
return default_file_system->FreeMappedMemory(data, size);
}

Status FileSystem::FlushMappedMemory(void *data, size_t size) {
return Status(ENOSYS, "Memory-mapped files not supported");
}

Status FileSystem::FreeMappedMemory(void *data, size_t size) {
return Status(ENOSYS, "Memory-mapped files not supported");
}
Expand Down
13 changes: 12 additions & 1 deletion sling/file/file.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class File {
Status WriteLine(const string &line);

// Map file region into memory. Return null on error or if not supported.
virtual void *MapMemory(uint64 pos, size_t size);
virtual void *MapMemory(uint64 pos, size_t size, bool writable = false);

// Set the current file position.
virtual Status Seek(uint64 pos) = 0;
Expand Down Expand Up @@ -147,6 +147,11 @@ class File {
// Find file names matching pattern.
static Status Match(const string &pattern,
std::vector<string> *filenames);
static std::vector<string> Match(const string &pattern) {
std::vector<string> filenames;
CHECK(Match(pattern, &filenames));
return filenames;
}

// Read contents of file.
static Status ReadContents(const string &filename, string *data);
Expand All @@ -161,6 +166,9 @@ class File {
// Return page size for memory mapping.
static size_t PageSize();

// Flush mapped memory to disk.
static Status FlushMappedMemory(void *data, size_t size);

// Free memory mapping.
static Status FreeMappedMemory(void *data, size_t size);
};
Expand Down Expand Up @@ -210,6 +218,9 @@ class FileSystem : public Singleton<FileSystem> {
virtual Status Match(const string &pattern,
std::vector<string> *filenames) = 0;

// Flish mapped memory to disk.
virtual Status FlushMappedMemory(void *data, size_t size);

// Release mapped memory.
virtual Status FreeMappedMemory(void *data, size_t size);
};
Expand Down
9 changes: 7 additions & 2 deletions sling/file/posix.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,9 @@ class PosixFile : public File {
return Status::OK;
}

void *MapMemory(uint64 pos, size_t size) override {
void *MapMemory(uint64 pos, size_t size, bool writable) override {
void *mapping = mmap(nullptr, size, PROT_READ | PROT_WRITE,
MAP_PRIVATE, fd_, pos);
writable ? MAP_SHARED : MAP_PRIVATE, fd_, pos);
return mapping == MAP_FAILED ? nullptr : mapping;
}

Expand Down Expand Up @@ -271,6 +271,11 @@ class PosixFileSystem : public FileSystem {
return Status::OK;
}

Status FlushMappedMemory(void *data, size_t size) {
if (msync(data, size, MS_SYNC) != 0) return IOError("msync", errno);
return Status::OK;
}

Status FreeMappedMemory(void *data, size_t size) override {
if (munmap(data, size) != 0) return IOError("munmap", errno);
return Status::OK;
Expand Down
1 change: 1 addition & 0 deletions sling/file/recordio.cc
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,7 @@ RecordWriter::RecordWriter(RecordReader *reader,
file_ = reader->file();
info_ = reader->info();
info_.index_page_size = options.index_page_size;
position_ = reader->size();
}

RecordWriter::~RecordWriter() {
Expand Down
10 changes: 7 additions & 3 deletions sling/myelin/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -280,10 +280,14 @@ class FlowBuilder : public Scope {
return Div(Sum(x), Const(size));
}
Variable *Count(Variable *p, Type type = DT_FLOAT) {
return Op("Count", {p}, type, {});
return NoGradient(Op("Count", {p}, type, {}));
}
Variable *ArgMin(Variable *x) {
return NoGradient(Op("ArgMin", {x}, DT_INT32, {}));
}
Variable *ArgMax(Variable *x) {
return NoGradient(Op("ArgMax", {x}, DT_INT32, {}));
}
Variable *ArgMin(Variable *x) { return Op("ArgMin", {x}, DT_INT32, {}); }
Variable *ArgMax(Variable *x) { return Op("ArgMax", {x}, DT_INT32, {}); }

// Dot product between two vectors.
Variable *DotProduct(Variable *x, Variable *y) {
Expand Down
23 changes: 0 additions & 23 deletions sling/myelin/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -250,29 +250,6 @@ void Compiler::WriteGraph(const Flow &flow,
}
}

void LogProfile(const Network &net) {
if (net.options().global_profiler) {
LOG(INFO) << "Profiling report:\n" << ProfileReport(net);
}
}

string ProfileReport(const Network &net) {
string report;
if (net.options().global_profiler) {
ProfileOverview overview;
for (const Cell *cell : net.cells()) {
Profile profile(cell->profile_summary());
report.append(profile.ASCIIReport());
report.append("\n");
overview.Add(profile);
}
report.append("Summary:\n");
report.append(overview.ASCIIReport());
report.append("\n");
}
return report;
}

void SetCPUFeatures(const string &features) {
const char *p = features.c_str();

Expand Down
6 changes: 0 additions & 6 deletions sling/myelin/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,6 @@ class Compiler {
// Enable/disable CPU features for compiler.
void SetCPUFeatures(const string &features);

// Log profile report if profiling enabled.
void LogProfile(const Network &net);

// Return profile report if profiling enabled.
string ProfileReport(const Network &net);

} // namespace myelin
} // namespace sling

Expand Down
16 changes: 0 additions & 16 deletions sling/myelin/flow.h
Original file line number Diff line number Diff line change
Expand Up @@ -854,9 +854,6 @@ class Transformer {
// Flow graph transformations.
class Transformations {
public:
// Gradient function for differentiation of ops.
typedef void (GradientFunc)(Flow::Operation *op, Gradients *g);

~Transformations();

// Register flow transformation component. Transfers ownership from caller.
Expand All @@ -869,11 +866,6 @@ class Transformations {
typers_.emplace_back(typer);
}

// Register gradient function for op.
void RegisterGradient(const string &op, GradientFunc *func) {
gradients_[op] = func;
}

// Flow transformation components.
const std::vector<Transformer *> &transformers() const {
return transformers_;
Expand All @@ -884,20 +876,12 @@ class Transformations {
return typers_;
}

// Gradient functions.
const std::unordered_map<string, GradientFunc *> &gradients() const {
return gradients_;
}

private:
// Flow transformation components.
std::vector<Transformer *> transformers_;

// Type inference components.
std::vector<Typer *> typers_;

// Gradient components.
std::unordered_map<string, GradientFunc *> gradients_;
};

// Return name of corresponding gradient variable.
Expand Down
17 changes: 14 additions & 3 deletions sling/myelin/gradient.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
namespace sling {
namespace myelin {

// Registered gradient components.
static GradientFuncs gradient_funcs;

// Return last part of name.
static string basename(const string &name) {
int slash = name.rfind('/');
Expand Down Expand Up @@ -136,7 +139,7 @@ void Gradients::MarkReferences() {

Flow::Function *Gradient(Flow *flow,
Flow::Function *func,
const Transformations &library) {
const GradientFuncs &funcs) {
// Get variables for gradients.
std::vector<Flow::Variable *> vars;
std::vector<Flow::Operation *> ops;
Expand All @@ -148,8 +151,8 @@ Flow::Function *Gradient(Flow *flow,
Flow::Operation *op = ops[i];
if (op->is(Flow::Operation::NOGRADIENT)) continue;

auto f = library.gradients().find(op->type);
if (f == library.gradients().end()) {
auto f = funcs.find(op->type);
if (f == funcs.end()) {
LOG(FATAL) << "No gradient function for " << op->type;
}
auto *gradfunc = f->second;
Expand Down Expand Up @@ -188,6 +191,14 @@ Flow::Function *Gradient(Flow *flow,
return gradient;
}

void RegisterGradient(const string &op, GradientFunc *func) {
gradient_funcs[op] = func;
}

Flow::Function *Gradient(Flow *flow, Flow::Function *func) {
return Gradient(flow, func, gradient_funcs);
}

} // namespace myelin
} // namespace sling

13 changes: 11 additions & 2 deletions sling/myelin/gradient.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#ifndef SLING_MYELIN_GRADIENT_H_
#define SLING_MYELIN_GRADIENT_H_

#include <string>
#include <vector>
#include <unordered_map>

Expand Down Expand Up @@ -76,13 +77,21 @@ class Gradients : public FlowBuilder {
std::unordered_map<Flow::Variable *, Flow::Variable *> refs_;
};

// Gradient function for differentiation of ops.
typedef void (GradientFunc)(Flow::Operation *op, Gradients *g);
typedef std::unordered_map<string, GradientFunc *> GradientFuncs;

// Register gradient function.
void RegisterGradient(const string &op, GradientFunc *func);

// Build gradient for function.
Flow::Function *Gradient(Flow *flow,
Flow::Function *func,
const Transformations &library);
const GradientFuncs &funcs);
Flow::Function *Gradient(Flow *flow, Flow::Function *func);

} // namespace myelin
} // namespace sling

#endif // SLING_MYELIN_BUILDER_H_
#endif // SLING_MYELIN_GRADIENT_H_

2 changes: 2 additions & 0 deletions sling/myelin/kernel/array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -881,6 +881,7 @@ class PoolingGather : public Kernel {
int vecbytes = SIMDAssembler::VectorBytes(type);
bool aligned = M->stride(0) % vecbytes == 0;
SIMDAssembler sasm(masm, type, aligned);
step->set_variant(sasm.name());

// Compute vector processing strategy.
SIMDStrategy strategy(&sasm, n);
Expand Down Expand Up @@ -1160,6 +1161,7 @@ class AssignAddScatter : public Kernel {
int vecbytes = SIMDAssembler::VectorBytes(type);
bool aligned = args.var->stride(0) % vecbytes == 0;
SIMDAssembler sasm(masm, type, aligned);
step->set_variant(sasm.name());

// Compute vector processing strategy.
SIMDStrategy strategy(&sasm, n);
Expand Down
Loading

0 comments on commit 5cb3d0c

Please sign in to comment.