Skip to content

Commit

Permalink
Add recorder
Browse files Browse the repository at this point in the history
  • Loading branch information
Algy committed Sep 8, 2019
1 parent 05369dc commit 876a74a
Show file tree
Hide file tree
Showing 6 changed files with 124 additions and 2 deletions.
5 changes: 5 additions & 0 deletions cfast_slic.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,15 @@ cdef extern from "src/context.h" namespace "fslic":
float preemptive_thres

bool manhattan_spatial_dist
bool debug_mode

BaseContext(int H, int W, int K, const uint8_t* image, Cluster *clusters) except +
void initialize_clusters() nogil
void initialize_state() nogil
bool parallelism_supported() nogil
void iterate(uint16_t *assignment, int max_iter) nogil except +
string get_timing_report();
string get_recorder_report();

cdef cppclass Context(BaseContext[uint16_t]):
Context(int H, int W, int K, const uint8_t* image, Cluster *clusters) except +
Expand Down Expand Up @@ -111,8 +113,11 @@ cdef class SlicModel:
cdef public object preemptive
cdef public float preemptive_thres
cdef public object manhattan_spatial_dist
cdef public object debug_mode

cdef public object float_color
cdef public object last_timing_report
cdef public object last_recorder_report

cpdef void initialize(self, const uint8_t [:, :, ::1] image)
cpdef iterate(self, const uint8_t [:, :, ::1] image, int max_iter, float compactness, float min_size_factor, uint8_t subsample_stride)
Expand Down
5 changes: 5 additions & 0 deletions cfast_slic.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ cdef class SlicModel:
self.real_dist_type = "standard"
self.convert_to_lab = False
self.float_color = True
self.debug_mode = False

self._c_clusters = <cfast_slic.Cluster *>malloc(sizeof(cfast_slic.Cluster) * num_components)
memset(self._c_clusters, 0, sizeof(cfast_slic.Cluster) * num_components)
Expand Down Expand Up @@ -183,6 +184,7 @@ cdef class SlicModel:
context.preemptive = self.preemptive
context.preemptive_thres = self.preemptive_thres
context.manhattan_spatial_dist = self.manhattan_spatial_dist
context.debug_mode = self.debug_mode
with nogil:
context.initialize_state()
context.iterate(
Expand All @@ -191,6 +193,7 @@ cdef class SlicModel:
)
finally:
self.last_timing_report = context.get_timing_report().decode("utf-8")
self.last_recorder_report = context.get_recorder_report()
del context
else:
if self.real_dist_type == 'standard':
Expand Down Expand Up @@ -241,6 +244,7 @@ cdef class SlicModel:
context_real_dist.preemptive = self.preemptive
context_real_dist.preemptive_thres = self.preemptive_thres
context_real_dist.manhattan_spatial_dist = self.manhattan_spatial_dist
context_real_dist.debug_mode = self.debug_mode
with nogil:
context_real_dist.initialize_state()
context_real_dist.iterate(
Expand All @@ -249,6 +253,7 @@ cdef class SlicModel:
)
finally:
self.last_timing_report = context_real_dist.get_timing_report().decode("utf-8")
self.last_recorder_report = context_real_dist.get_recorder_report()
del context_real_dist
result = assignments.astype(np.int16)
result[result == 0xFFFF] = -1
Expand Down
2 changes: 2 additions & 0 deletions fast_slic/base_slic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def __init__(self,
preemptive=False,
preemptive_thres=0.05,
manhattan_spatial_dist=False,
debug_mode=False,
num_threads=-1):
self.compactness = compactness
self.subsample_stride = subsample_stride
Expand All @@ -25,6 +26,7 @@ def __init__(self,
self._slic_model.preemptive_thres = preemptive_thres
self._slic_model.manhattan_spatial_dist = manhattan_spatial_dist
self._slic_model.num_threads = num_threads
self._slic_model.debug_mode = debug_mode

@property
def convert_to_lab(self):
Expand Down
4 changes: 3 additions & 1 deletion src/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,8 @@ namespace fslic {
before_iteration();
}
preemptive_grid.initialize(clusters, preemptive, preemptive_thres, subsample_stride);

recorder.initialize(debug_mode);
recorder.push(-1, this->assignment, this->min_dists, this->clusters);
for (int i = 0; i < max_iter; i++) {
{
fstimer::Scope s("assign");
Expand All @@ -169,6 +170,7 @@ namespace fslic {
fstimer::Scope s("after_update");
after_update();
}
recorder.push(i, this->assignment, this->min_dists, this->clusters);
subsample_rem = (subsample_rem + 1) % subsample_stride;
}
preemptive_grid.finalize(clusters);
Expand Down
7 changes: 6 additions & 1 deletion src/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "simd-helper.hpp"
#include "fast-slic-common.h"
#include "preemptive.h"
#include "recorder.h"

typedef std::chrono::high_resolution_clock Clock;

Expand All @@ -32,6 +33,7 @@ namespace fslic {
float preemptive_thres = 0.01;

bool manhattan_spatial_dist = true;
bool debug_mode = false;
protected:
int H, W, K;
const uint8_t* image;
Expand All @@ -50,6 +52,7 @@ namespace fslic {
simd_helper::AlignedArray<DistType> spatial_dist_patch;

PreemptiveGrid preemptive_grid;
Recorder<DistType> recorder;
public:
std::string last_timing_report;
public:
Expand All @@ -59,7 +62,8 @@ namespace fslic {
assignment(H, W, S, S, S, S),
min_dists(H, W, S, S, S, S),
spatial_dist_patch(2 * S + 1, 2 * S + 1),
preemptive_grid(H, W, K, S) {};
preemptive_grid(H, W, K, S),
recorder(H, W, K) {};
virtual ~BaseContext();
public:
void initialize_clusters();
Expand All @@ -68,6 +72,7 @@ namespace fslic {
bool parallelism_supported();
void iterate(uint16_t *assignment, int max_iter);
std::string get_timing_report() { return last_timing_report; };
std::string get_recorder_report() { return recorder.get_report(); };
private:
void prepare_spatial();
void assign();
Expand Down
103 changes: 103 additions & 0 deletions src/recorder.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
#ifndef _FAST_SLIC_RECORDER_H
#define _FAST_SLIC_RECORDER_H
#include <vector>
#include <string>
#include <cstdint>
#include <sstream>
#include <algorithm>
#include "simd-helper.hpp"

namespace fslic {
template <class DistType>
struct RecorderSnapshot {
int iteration;
std::vector<uint16_t> assignment;
std::vector<DistType> min_dists;
std::vector<Cluster> clusters;

void gen(std::stringstream& stream) {
stream << "{\"iteration\": " << iteration;
stream << ", \"clusters\": [";

for (int i = 0; i < (int)clusters.size(); i++) {
Cluster &cluster = clusters[i];
if (i > 0) stream << ",";
stream << "{\"yx\": [" << cluster.y << "," << cluster.x << "]";
stream << ", \"color\": [" << cluster.r << "," << cluster.g << "," << cluster.b << "]";
stream << ", \"is_updatable\": " << (int)cluster.is_updatable;
stream << ", \"is_active\": " << (int)cluster.is_active;
stream << ", \"number\": " << cluster.number;
stream << ", \"num_members\": " << cluster.num_members;
stream << "}";
}
stream << "]";
stream << ", \"assignment\": [";
for (int i = 0; i < (int)assignment.size(); i++) {
if (i > 0) stream << ",";
stream << assignment[i];
}
stream << "]";
stream << ", \"min_dists\": [";
for (int i = 0; i < (int)min_dists.size(); i++) {
if (i > 0) stream << ",";
stream << min_dists[i];
}
stream << "]";
stream << "}";
}
};

template <class DistType>
class Recorder {
private:
std::vector<RecorderSnapshot<DistType>> snapshots;
int H, W, K;
bool enabled;
public:
Recorder(int H, int W, int K) : H(H), W(W), K(K), enabled(false) {};
void initialize(bool enabled) {
this->enabled = enabled;
}
void push(
int iter,
const simd_helper::AlignedArray<uint16_t> &assignment,
const simd_helper::AlignedArray<DistType> &min_dists,
const Cluster *clusters) {
if (!enabled) return;
RecorderSnapshot<DistType> snapshot;
snapshot.iteration = iter;
snapshot.assignment.resize(H * W);
snapshot.min_dists.resize(H * W);
snapshot.clusters.resize(K);

for (int i = 0; i < H; i++) {
for (int j = 0; j < W; j++) {
snapshot.assignment[W * i + j] = assignment.get(i, j);
snapshot.min_dists[W * i + j] = min_dists.get(i, j);
}
}
std::copy(clusters, clusters + K, snapshot.clusters.begin());
snapshots.push_back(snapshot);
}

std::string get_report() {
std::stringstream stream;
gen(stream);
return stream.str();
}

private:
void gen(std::stringstream& stream) {
stream << "{\"height\": " << H;
stream << ", \"width\": " << W;
stream << ", \"snapshots\": [";
for (int i = 0; i < (int)snapshots.size(); i++) {
if (i > 0) stream << ",";
snapshots[i].gen(stream);
}
stream << "]}";
}
};
}

#endif

0 comments on commit 876a74a

Please sign in to comment.