diff --git a/modules/dnn/CMakeLists.txt b/modules/dnn/CMakeLists.txt index fafb82257b8..782ae94d91f 100644 --- a/modules/dnn/CMakeLists.txt +++ b/modules/dnn/CMakeLists.txt @@ -42,8 +42,12 @@ if(BUILD_TESTS AND ${the_module}_DOWNLOAD_CAFFE_MODELS) add_custom_command( TARGET opencv_test_${name} POST_BUILD COMMAND ${PYTHON2_EXECUTABLE} download_model.py test_models.json WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/scripts ) -else() - add_definitions(-DDISABLE_CAFFE_MODEL_TESTS=1) + add_definitions(-DENABLE_CAFFE_MODEL_TESTS=1) +endif() + +OCV_OPTION(${the_module}_BUILD_TORCH_IMPORTER "Build Torch model importer" OFF) +if(${the_module}_BUILD_TORCH_IMPORTER) + add_definitions(-DENABLE_TORCH_IMPORTER=1) endif() else()#build as standalone module (for development purposes) diff --git a/modules/dnn/include/opencv2/dnn/dnn.hpp b/modules/dnn/include/opencv2/dnn/dnn.hpp index fd23164b6fc..83230ac91cc 100644 --- a/modules/dnn/include/opencv2/dnn/dnn.hpp +++ b/modules/dnn/include/opencv2/dnn/dnn.hpp @@ -90,6 +90,8 @@ namespace dnn CV_EXPORTS Ptr createCaffeImporter(const String &prototxt, const String &caffeModel = String()); + CV_EXPORTS Ptr createTorchImporter(const String &filename, bool isBinary = true); + //Layer factory allows to create instances of registered layers. class CV_EXPORTS LayerRegister { diff --git a/modules/dnn/src/torch/COPYRIGHT.txt b/modules/dnn/src/torch/COPYRIGHT.txt new file mode 100644 index 00000000000..c9cc78475c6 --- /dev/null +++ b/modules/dnn/src/torch/COPYRIGHT.txt @@ -0,0 +1,36 @@ +Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) +Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) +Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) +Copyright (c) 2011-2013 NYU (Clement Farabet) +Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) +Copyright (c) 2006 Idiap Research Institute (Samy Bengio) +Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +3. Neither the names of Deepmind Technologies, NYU, NEC Laboratories America + and IDIAP Research Institute nor the names of its contributors may be + used to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. diff --git a/modules/dnn/src/torch/File.lua b/modules/dnn/src/torch/File.lua new file mode 100644 index 00000000000..1b86171b736 --- /dev/null +++ b/modules/dnn/src/torch/File.lua @@ -0,0 +1,348 @@ +local File = torch.getmetatable('torch.File') + +function File:writeBool(value) + if value then + self:writeInt(1) + else + self:writeInt(0) + end +end + +function File:readBool() + return (self:readInt() == 1) +end + +local TYPE_NIL = 0 +local TYPE_NUMBER = 1 +local TYPE_STRING = 2 +local TYPE_TABLE = 3 +local TYPE_TORCH = 4 +local TYPE_BOOLEAN = 5 +local TYPE_FUNCTION = 6 +local TYPE_RECUR_FUNCTION = 8 +local LEGACY_TYPE_RECUR_FUNCTION = 7 + +-- Lua 5.2 compatibility +local loadstring = loadstring or load + +function File:isWritableObject(object) + local typename = type(object) + local typeidx + if type(object) ~= 'boolean' and not object then + typeidx = TYPE_NIL + elseif torch.typename(object) and torch.factory(torch.typename(object)) then + typeidx = TYPE_TORCH + elseif typename == 'table' then + typeidx = TYPE_TABLE + elseif typename == 'number' then + typeidx = TYPE_NUMBER + elseif typename == 'string' then + typeidx = TYPE_STRING + elseif typename == 'boolean' then + typeidx = TYPE_BOOLEAN + elseif typename == 'function' and pcall(string.dump, object) then + typeidx = TYPE_RECUR_FUNCTION + end + return typeidx +end + +function File:referenced(ref) + -- we use an environment to keep a record of written objects + if not torch.getenv(self).writeObjects then + torch.setenv(self, {writeObjects={}, writeObjectsRef={}, readObjects={}}) + end + local env = torch.getenv(self) + env.force = not ref + torch.setenv(self,env) + return self +end + +function File:isReferenced() + -- if no environment, then no forcing setup yet + if not torch.getenv(self).writeObjects then + return true + end + local env = torch.getenv(self) + return not env.force +end + +local function getmetamethod(obj, name) + local func + local status + + -- check getmetatable(obj).__name or + -- check getmetatable(obj).name + status, func = pcall( + function() + -- note that sometimes the metatable is hidden + -- we get it for sure through the torch type system + local mt = torch.getmetatable(torch.typename(obj)) + if mt then + return mt['__' .. name] or mt[name] + end + end + ) + if status and type(func) == 'function' then + return func + end +end + +function File:writeObject(object) + -- we use an environment to keep a record of written objects + if not torch.getenv(self).writeObjects then + torch.setenv(self, {writeObjects={}, writeObjectsRef={}, readObjects={}}) + end + + local force = torch.getenv(self).force + + -- if nil object, only write the type and return + if type(object) ~= 'boolean' and not object then + self:writeInt(TYPE_NIL) + return + end + + -- check the type we are dealing with + local typeidx = self:isWritableObject(object) + if not typeidx then + error(string.format('Unwritable object <%s>', type(object))) + end + self:writeInt(typeidx) + + if typeidx == TYPE_NUMBER then + self:writeDouble(object) + elseif typeidx == TYPE_BOOLEAN then + self:writeBool(object) + elseif typeidx == TYPE_STRING then + local stringStorage = torch.CharStorage():string(object) + self:writeInt(#stringStorage) + self:writeChar(stringStorage) + elseif typeidx == TYPE_TORCH or typeidx == TYPE_TABLE or typeidx == TYPE_RECUR_FUNCTION then + -- check it exists already (we look at the pointer!) + local objects = torch.getenv(self).writeObjects + local objectsRef = torch.getenv(self).writeObjectsRef + local index = objects[torch.pointer(object)] + + if index and (not force) then + -- if already exists, write only its index + self:writeInt(index) + else + -- else write the object itself + index = objects.nWriteObject or 0 + index = index + 1 + objects[torch.pointer(object)] = index + if not force then + objectsRef[object] = index -- we make sure the object is not going to disappear + end + self:writeInt(index) + objects.nWriteObject = index + if typeidx == TYPE_RECUR_FUNCTION then + local upvalues = {} + local counter = 0 + while true do + counter = counter + 1 + local name,value = debug.getupvalue(object, counter) + if not name then break end + if name == '_ENV' then value = nil end + table.insert(upvalues, {name=name, value=value}) + end + local dumped = string.dump(object) + local stringStorage = torch.CharStorage():string(dumped) + self:writeInt(#stringStorage) + self:writeChar(stringStorage) + self:writeObject(upvalues) + elseif typeidx == TYPE_TORCH then + local version = torch.CharStorage():string('V ' .. torch.version(object)) + local className = torch.CharStorage():string(torch.typename(object)) + self:writeInt(#version) + self:writeChar(version) + self:writeInt(#className) + self:writeChar(className) + local write = getmetamethod(object, 'write') + if write then + write(object, self) + elseif type(object) == 'table' then + local var = {} + for k,v in pairs(object) do + if self:isWritableObject(v) then + var[k] = v + else + print(string.format('$ Warning: cannot write object field <%s>', k)) + end + end + self:writeObject(var) + else + error(string.format('<%s> is a non-serializable Torch object', torch.typename(object))) + end + else -- it is a table + local size = 0; for k,v in pairs(object) do size = size + 1 end + self:writeInt(size) + for k,v in pairs(object) do + self:writeObject(k) + self:writeObject(v) + end + end + end + else + error('Unwritable object') + end +end + +function File:readObject() + -- we use an environment to keep a record of read objects + if not torch.getenv(self).writeObjects then + torch.setenv(self, {writeObjects={}, writeObjectsRef={}, readObjects={}}) + end + + local force = torch.getenv(self).force + + -- read the typeidx + local typeidx = self:readInt() + + -- is it nil? + if typeidx == TYPE_NIL then + return nil + end + + if typeidx == TYPE_NUMBER then + return self:readDouble() + elseif typeidx == TYPE_BOOLEAN then + return self:readBool() + elseif typeidx == TYPE_STRING then + local size = self:readInt() + return self:readChar(size):string() + elseif typeidx == TYPE_FUNCTION then + local size = self:readInt() + local dumped = self:readChar(size):string() + local func = loadstring(dumped) + local upvalues = self:readObject() + for index,upvalue in ipairs(upvalues) do + debug.setupvalue(func, index, upvalue) + end + return func + elseif typeidx == TYPE_TABLE or typeidx == TYPE_TORCH or typeidx == TYPE_RECUR_FUNCTION or typeidx == LEGACY_TYPE_RECUR_FUNCTION then + -- read the index + local index = self:readInt() + + -- check it is loaded already + local objects = torch.getenv(self).readObjects + if objects[index] and not force then + return objects[index] + end + + -- otherwise read it + if typeidx == TYPE_RECUR_FUNCTION or typeidx == LEGACY_TYPE_RECUR_FUNCTION then + local size = self:readInt() + local dumped = self:readChar(size):string() + local func = loadstring(dumped) + objects[index] = func + local upvalues = self:readObject() + for index,upvalue in ipairs(upvalues) do + if typeidx == LEGACY_TYPE_RECUR_FUNCTION then + debug.setupvalue(func, index, upvalue) + elseif upvalue.name == '_ENV' then + debug.setupvalue(func, index, _ENV) + else + debug.setupvalue(func, index, upvalue.value) + end + end + return func + elseif typeidx == TYPE_TORCH then + local version, className, versionNumber + version = self:readChar(self:readInt()):string() + versionNumber = tonumber(string.match(version, '^V (.*)$')) + if not versionNumber then + className = version + versionNumber = 0 -- file created before existence of versioning system + else + className = self:readChar(self:readInt()):string() + end + if not torch.factory(className) then + error(string.format('unknown Torch class <%s>', tostring(className))) + end + local object = torch.factory(className)(self) + objects[index] = object + local read = getmetamethod(object, 'read') + if read then + read(object, self, versionNumber) + elseif type(object) == 'table' then + local var = self:readObject() + for k,v in pairs(var) do + object[k] = v + end + else + error(string.format('Cannot load object class <%s>', tostring(className))) + end + return object + else -- it is a table + local size = self:readInt() + local object = {} + objects[index] = object + for i = 1,size do + local k = self:readObject() + local v = self:readObject() + object[k] = v + end + return object + end + else + error('unknown object') + end +end + +-- simple helpers to save/load arbitrary objects/tables +function torch.save(filename, object, mode) + mode = mode or 'binary' + local file = torch.DiskFile(filename, 'w') + file[mode](file) + file:writeObject(object) + file:close() +end + +function torch.load(filename, mode) + mode = mode or 'binary' + local file = torch.DiskFile(filename, 'r') + file[mode](file) + local object = file:readObject() + file:close() + return object +end + +-- simple helpers to serialize/deserialize arbitrary objects/tables +function torch.serialize(object, mode) + local storage = torch.serializeToStorage(object, mode) + return storage:string() +end + +-- Serialize to a CharStorage, not a lua string. This avoids +function torch.serializeToStorage(object, mode) + mode = mode or 'binary' + local f = torch.MemoryFile() + f = f[mode](f) + f:writeObject(object) + local storage = f:storage() + f:close() + return storage +end + +function torch.deserializeFromStorage(storage, mode) + mode = mode or 'binary' + local tx = torch.CharTensor(storage) + local xp = torch.CharStorage(tx:size(1)+1) + local txp = torch.CharTensor(xp) + txp:narrow(1,1,tx:size(1)):copy(tx) + txp[tx:size(1)+1] = 0 + local f = torch.MemoryFile(xp) + f = f[mode](f) + local object = f:readObject() + f:close() + return object +end + +function torch.deserialize(str, mode) + local storage = torch.CharStorage():string(str) + return torch.deserializeFromStorage(storage, mode) +end + +-- public API (saveobj/loadobj are safe for global import) +torch.saveobj = torch.save +torch.loadobj = torch.load diff --git a/modules/dnn/src/torch/THDiskFile.cpp b/modules/dnn/src/torch/THDiskFile.cpp new file mode 100644 index 00000000000..ca8247dae42 --- /dev/null +++ b/modules/dnn/src/torch/THDiskFile.cpp @@ -0,0 +1,609 @@ +#include "THGeneral.h" +#include "THDiskFile.h" +#include "THFilePrivate.h" + +extern "C" +{ + +typedef struct THDiskFile__ +{ + THFile file; + + FILE *handle; + char *name; + int isNativeEncoding; + +} THDiskFile; + +static int THDiskFile_isOpened(THFile *self) +{ + THDiskFile *dfself = (THDiskFile*)self; + return (dfself->handle != NULL); +} + +const char *THDiskFile_name(THFile *self) +{ + THDiskFile *dfself = (THDiskFile*)self; + return dfself->name; +} + +/* workaround mac osx lion ***insane*** fread bug */ +#ifdef __APPLE__ +size_t fread__(void *ptr, size_t size, size_t nitems, FILE *stream) +{ + size_t nread = 0; + while(!feof(stream) && !ferror(stream) && (nread < nitems)) + nread += fread((char*)ptr+nread*size, size, THMin(2147483648/size, nitems-nread), stream); + return nread; +} +#else +#define fread__ fread +#endif + +#define READ_WRITE_METHODS(TYPE, TYPEC, ASCII_READ_ELEM, ASCII_WRITE_ELEM) \ + static long THDiskFile_read##TYPEC(THFile *self, TYPE *data, long n) \ + { \ + THDiskFile *dfself = (THDiskFile*)(self); \ + long nread = 0L; \ + \ + THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file"); \ + THArgCheck(dfself->file.isReadable, 1, "attempt to read in a write-only file"); \ + \ + if(dfself->file.isBinary) \ + { \ + nread = fread__(data, sizeof(TYPE), n, dfself->handle); \ + if(!dfself->isNativeEncoding && (sizeof(TYPE) > 1) && (nread > 0)) \ + THDiskFile_reverseMemory(data, data, sizeof(TYPE), nread); \ + } \ + else \ + { \ + long i; \ + for(i = 0; i < n; i++) \ + { \ + ASCII_READ_ELEM; /* increment here result and break if wrong */ \ + } \ + if(dfself->file.isAutoSpacing && (n > 0)) \ + { \ + int c = fgetc(dfself->handle); \ + if( (c != '\n') && (c != EOF) ) \ + ungetc(c, dfself->handle); \ + } \ + } \ + \ + if(nread != n) \ + { \ + dfself->file.hasError = 1; /* shouldn't we put hasError to 0 all the time ? */ \ + if(!dfself->file.isQuiet) \ + THError("read error: read %d blocks instead of %d", nread, n); \ + } \ + \ + return nread; \ + } \ + \ + static long THDiskFile_write##TYPEC(THFile *self, TYPE *data, long n) \ + { \ + THDiskFile *dfself = (THDiskFile*)(self); \ + long nwrite = 0L; \ + \ + THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file"); \ + THArgCheck(dfself->file.isWritable, 1, "attempt to write in a read-only file"); \ + \ + if(dfself->file.isBinary) \ + { \ + if(dfself->isNativeEncoding) \ + { \ + nwrite = fwrite(data, sizeof(TYPE), n, dfself->handle); \ + } \ + else \ + { \ + if(sizeof(TYPE) > 1) \ + { \ + char *buffer = (char*)THAlloc(sizeof(TYPE)*n); \ + THDiskFile_reverseMemory(buffer, data, sizeof(TYPE), n); \ + nwrite = fwrite(buffer, sizeof(TYPE), n, dfself->handle); \ + THFree(buffer); \ + } \ + else \ + nwrite = fwrite(data, sizeof(TYPE), n, dfself->handle); \ + } \ + } \ + else \ + { \ + long i; \ + for(i = 0; i < n; i++) \ + { \ + ASCII_WRITE_ELEM; \ + if( dfself->file.isAutoSpacing && (i < n-1) ) \ + fprintf(dfself->handle, " "); \ + } \ + if(dfself->file.isAutoSpacing && (n > 0)) \ + fprintf(dfself->handle, "\n"); \ + } \ + \ + if(nwrite != n) \ + { \ + dfself->file.hasError = 1; \ + if(!dfself->file.isQuiet) \ + THError("write error: wrote %d blocks instead of %d", nwrite, n); \ + } \ + \ + return nwrite; \ +} + +static int THDiskFile_mode(const char *mode, int *isReadable, int *isWritable) +{ + *isReadable = 0; + *isWritable = 0; + if(strlen(mode) == 1) + { + if(*mode == 'r') + { + *isReadable = 1; + return 1; + } + else if(*mode == 'w') + { + *isWritable = 1; + return 1; + } + } + else if(strlen(mode) == 2) + { + if(mode[0] == 'r' && mode[1] == 'w') + { + *isReadable = 1; + *isWritable = 1; + return 1; + } + } + return 0; +} + +static void THDiskFile_synchronize(THFile *self) +{ + THDiskFile *dfself = (THDiskFile*)(self); + THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file"); + fflush(dfself->handle); +} + +static void THDiskFile_seek(THFile *self, long position) +{ + THDiskFile *dfself = (THDiskFile*)(self); + + THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file"); + THArgCheck(position >= 0, 2, "position must be positive"); + + if(fseek(dfself->handle, position, SEEK_SET) < 0) + { + dfself->file.hasError = 1; + if(!dfself->file.isQuiet) + THError("unable to seek at position %d", position); + } +} + +static void THDiskFile_seekEnd(THFile *self) +{ + THDiskFile *dfself = (THDiskFile*)(self); + + THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file"); + + if(fseek(dfself->handle, 0L, SEEK_END) < 0) + { + dfself->file.hasError = 1; + if(!dfself->file.isQuiet) + THError("unable to seek at end of file"); + } +} + +static long THDiskFile_position(THFile *self) +{ + THDiskFile *dfself = (THDiskFile*)(self); + THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file"); + return ftell(dfself->handle); +} + +static void THDiskFile_close(THFile *self) +{ + THDiskFile *dfself = (THDiskFile*)(self); + THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file"); + fclose(dfself->handle); + dfself->handle = NULL; +} + +/* Little and Big Endian */ + +static void THDiskFile_reverseMemory(void *dst, const void *src, long blockSize, long numBlocks) +{ + if(blockSize != 1) + { + long halfBlockSize = blockSize/2; + char *charSrc = (char*)src; + char *charDst = (char*)dst; + long b, i; + for(b = 0; b < numBlocks; b++) + { + for(i = 0; i < halfBlockSize; i++) + { + char z = charSrc[i]; + charDst[i] = charSrc[blockSize-1-i]; + charDst[blockSize-1-i] = z; + } + charSrc += blockSize; + charDst += blockSize; + } + } +} + +int THDiskFile_isLittleEndianCPU(void) +{ + int x = 7; + char *ptr = (char *)&x; + + if(ptr[0] == 0) + return 0; + else + return 1; +} + +int THDiskFile_isBigEndianCPU(void) +{ + return(!THDiskFile_isLittleEndianCPU()); +} + +void THDiskFile_nativeEndianEncoding(THFile *self) +{ + THDiskFile *dfself = (THDiskFile*)(self); + THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file"); + dfself->isNativeEncoding = 1; +} + +void THDiskFile_littleEndianEncoding(THFile *self) +{ + THDiskFile *dfself = (THDiskFile*)(self); + THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file"); + dfself->isNativeEncoding = THDiskFile_isLittleEndianCPU(); +} + +void THDiskFile_bigEndianEncoding(THFile *self) +{ + THDiskFile *dfself = (THDiskFile*)(self); + THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file"); + dfself->isNativeEncoding = !THDiskFile_isLittleEndianCPU(); +} + +/* End of Little and Big Endian Stuff */ + +static void THDiskFile_free(THFile *self) +{ + THDiskFile *dfself = (THDiskFile*)(self); + if(dfself->handle) + fclose(dfself->handle); + THFree(dfself->name); + THFree(dfself); +} + +/* READ_WRITE_METHODS(int, Bool, */ +/* int value = 0; int ret = fscanf(file->handle, "%d", &value); array[i] = (value ? 1 : 0); if(ret <= 0) break; else result++, */ +/* int value = (array[i] ? 1 : 0); nElemWritten = fprintf(file->handle, "%d", value), */ +/* true) */ + +/* Note that we do a trick */ +READ_WRITE_METHODS(unsigned char, Byte, + nread = fread(data, 1, n, dfself->handle); break, + nwrite = fwrite(data, 1, n, dfself->handle); break) + +READ_WRITE_METHODS(char, Char, + nread = fread(data, 1, n, dfself->handle); break, + nwrite = fwrite(data, 1, n, dfself->handle); break) + +READ_WRITE_METHODS(short, Short, + int ret = fscanf(dfself->handle, "%hd", &data[i]); if(ret <= 0) break; else nread++, + int ret = fprintf(dfself->handle, "%hd", data[i]); if(ret <= 0) break; else nwrite++) + +READ_WRITE_METHODS(int, Int, + int ret = fscanf(dfself->handle, "%d", &data[i]); if(ret <= 0) break; else nread++, + int ret = fprintf(dfself->handle, "%d", data[i]); if(ret <= 0) break; else nwrite++) + +READ_WRITE_METHODS(long, Long, + int ret = fscanf(dfself->handle, "%ld", &data[i]); if(ret <= 0) break; else nread++, + int ret = fprintf(dfself->handle, "%ld", data[i]); if(ret <= 0) break; else nwrite++) + +READ_WRITE_METHODS(float, Float, + int ret = fscanf(dfself->handle, "%g", &data[i]); if(ret <= 0) break; else nread++, + int ret = fprintf(dfself->handle, "%.9g", data[i]); if(ret <= 0) break; else nwrite++) + +READ_WRITE_METHODS(double, Double, + int ret = fscanf(dfself->handle, "%lg", &data[i]); if(ret <= 0) break; else nread++, + int ret = fprintf(dfself->handle, "%.17g", data[i]); if(ret <= 0) break; else nwrite++) + +static long THDiskFile_readString(THFile *self, const char *format, char **str_) +{ + THDiskFile *dfself = (THDiskFile*)(self); + THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file"); + THArgCheck(dfself->file.isReadable, 1, "attempt to read in a write-only file"); + THArgCheck((strlen(format) >= 2 ? (format[0] == '*') && (format[1] == 'a' || format[1] == 'l') : 0), 2, "format must be '*a' or '*l'"); + +/* note: the string won't survive long, as it is copied into lua */ +/* so 1024 is not that big... */ +#define TBRS_BSZ 1024L + + if(format[1] == 'a') + { + char *p = (char*)THAlloc(TBRS_BSZ); + long total = TBRS_BSZ; + long pos = 0L; + + for (;;) + { + if(total-pos == 0) /* we need more space! */ + { + total += TBRS_BSZ; + p = (char*)THRealloc(p, total); + } + pos += fread(p+pos, 1, total-pos, dfself->handle); + if (pos < total) /* eof? */ + { + if(pos == 0L) + { + THFree(p); + dfself->file.hasError = 1; + if(!dfself->file.isQuiet) + THError("read error: read 0 blocks instead of 1"); + + *str_ = NULL; + return 0; + } + *str_ = p; + return pos; + } + } + } + else + { + char *p = (char*)THAlloc(TBRS_BSZ); + long total = TBRS_BSZ; + long pos = 0L; + long size; + + for (;;) + { + if(total-pos <= 1) /* we can only write '\0' in there! */ + { + total += TBRS_BSZ; + p = (char*)THRealloc(p, total); + } + if (fgets(p+pos, total-pos, dfself->handle) == NULL) /* eof? */ + { + if(pos == 0L) + { + THFree(p); + dfself->file.hasError = 1; + if(!dfself->file.isQuiet) + THError("read error: read 0 blocks instead of 1"); + + *str_ = NULL; + return 0; + } + *str_ = p; + return pos; + } + size = strlen(p+pos); + if (size == 0L || (p+pos)[size-1] != '\n') + { + pos += size; + } + else + { + pos += size-1L; /* do not include `eol' */ + *str_ = p; + return pos; + } + } + } + + *str_ = NULL; + return 0; +} + + +static long THDiskFile_writeString(THFile *self, const char *str, long size) +{ + THDiskFile *dfself = (THDiskFile*)(self); + long nwrite; + + THArgCheck(dfself->handle != NULL, 1, "attempt to use a closed file"); + THArgCheck(dfself->file.isWritable, 1, "attempt to write in a read-only file"); + + nwrite = fwrite(str, 1, size, dfself->handle); + if(nwrite != size) + { + dfself->file.hasError = 1; + if(!dfself->file.isQuiet) + THError("write error: wrote %ld blocks instead of %ld", nwrite, size); + } + + return nwrite; +} + +THFile *THDiskFile_new(const char *name, const char *mode, int isQuiet) +{ + static struct THFileVTable vtable = { + THDiskFile_isOpened, + + THDiskFile_readByte, + THDiskFile_readChar, + THDiskFile_readShort, + THDiskFile_readInt, + THDiskFile_readLong, + THDiskFile_readFloat, + THDiskFile_readDouble, + THDiskFile_readString, + + THDiskFile_writeByte, + THDiskFile_writeChar, + THDiskFile_writeShort, + THDiskFile_writeInt, + THDiskFile_writeLong, + THDiskFile_writeFloat, + THDiskFile_writeDouble, + THDiskFile_writeString, + + THDiskFile_synchronize, + THDiskFile_seek, + THDiskFile_seekEnd, + THDiskFile_position, + THDiskFile_close, + THDiskFile_free + }; + + int isReadable; + int isWritable; + FILE *handle; + THDiskFile *self; + + THArgCheck(THDiskFile_mode(mode, &isReadable, &isWritable), 2, "file mode should be 'r','w' or 'rw'"); + + if( isReadable && isWritable ) + { + handle = fopen(name, "r+b"); + if(!handle) + { + handle = fopen(name, "wb"); + if(handle) + { + fclose(handle); + handle = fopen(name, "r+b"); + } + } + } + else + handle = fopen(name, (isReadable ? "rb" : "wb")); + + if(!handle) + { + if(isQuiet) + return 0; + else + THError("cannot open <%s> in mode %c%c", name, (isReadable ? 'r' : ' '), (isWritable ? 'w' : ' ')); + } + + self = (THDiskFile*)THAlloc(sizeof(THDiskFile)); + + self->handle = handle; + self->name = (char*)THAlloc(strlen(name)+1); + strcpy(self->name, name); + self->isNativeEncoding = 1; + + self->file.vtable = &vtable; + self->file.isQuiet = isQuiet; + self->file.isReadable = isReadable; + self->file.isWritable = isWritable; + self->file.isBinary = 0; + self->file.isAutoSpacing = 1; + self->file.hasError = 0; + + return (THFile*)self; +} + +/* PipeFile */ + +static int THPipeFile_mode(const char *mode, int *isReadable, int *isWritable) +{ + *isReadable = 0; + *isWritable = 0; + if(strlen(mode) == 1) + { + if(*mode == 'r') + { + *isReadable = 1; + return 1; + } + else if(*mode == 'w') + { + *isWritable = 1; + return 1; + } + } + return 0; +} + +static void THPipeFile_free(THFile *self) +{ + THDiskFile *dfself = (THDiskFile*)(self); + if(dfself->handle) + pclose(dfself->handle); + THFree(dfself->name); + THFree(dfself); +} + +THFile *THPipeFile_new(const char *name, const char *mode, int isQuiet) +{ + static struct THFileVTable vtable = { + THDiskFile_isOpened, + + THDiskFile_readByte, + THDiskFile_readChar, + THDiskFile_readShort, + THDiskFile_readInt, + THDiskFile_readLong, + THDiskFile_readFloat, + THDiskFile_readDouble, + THDiskFile_readString, + + THDiskFile_writeByte, + THDiskFile_writeChar, + THDiskFile_writeShort, + THDiskFile_writeInt, + THDiskFile_writeLong, + THDiskFile_writeFloat, + THDiskFile_writeDouble, + THDiskFile_writeString, + + THDiskFile_synchronize, + THDiskFile_seek, + THDiskFile_seekEnd, + THDiskFile_position, + THDiskFile_close, + THPipeFile_free + }; + + int isReadable; + int isWritable; + FILE *handle; + THDiskFile *self; + + THArgCheck(THPipeFile_mode(mode, &isReadable, &isWritable), 2, "file mode should be 'r','w'"); + +#ifdef _WIN32 + handle = popen(name, (isReadable ? "rb" : "wb")); +#else + handle = popen(name, (isReadable ? "r" : "w")); +#endif + + if(!handle) + { + if(isQuiet) + return 0; + else + THError("cannot open <%s> in mode %c%c", name, (isReadable ? 'r' : ' '), (isWritable ? 'w' : ' ')); + } + + self = (THDiskFile*)THAlloc(sizeof(THDiskFile)); + + self->handle = handle; + self->name = (char*)THAlloc(strlen(name)+1); + strcpy(self->name, name); + self->isNativeEncoding = 1; + + self->file.vtable = &vtable; + self->file.isQuiet = isQuiet; + self->file.isReadable = isReadable; + self->file.isWritable = isWritable; + self->file.isBinary = 0; + self->file.isAutoSpacing = 1; + self->file.hasError = 0; + + return (THFile*)self; +} + +} diff --git a/modules/dnn/src/torch/THDiskFile.h b/modules/dnn/src/torch/THDiskFile.h new file mode 100644 index 00000000000..f7c93c220c6 --- /dev/null +++ b/modules/dnn/src/torch/THDiskFile.h @@ -0,0 +1,17 @@ +#ifndef TH_DISK_FILE_INC +#define TH_DISK_FILE_INC + +#include "THFile.h" + +TH_API THFile *THDiskFile_new(const char *name, const char *mode, int isQuiet); +TH_API THFile *THPipeFile_new(const char *name, const char *mode, int isQuiet); + +TH_API const char *THDiskFile_name(THFile *self); + +TH_API int THDiskFile_isLittleEndianCPU(void); +TH_API int THDiskFile_isBigEndianCPU(void); +TH_API void THDiskFile_nativeEndianEncoding(THFile *self); +TH_API void THDiskFile_littleEndianEncoding(THFile *self); +TH_API void THDiskFile_bigEndianEncoding(THFile *self); + +#endif diff --git a/modules/dnn/src/torch/THFile.cpp b/modules/dnn/src/torch/THFile.cpp new file mode 100644 index 00000000000..db71a066d8a --- /dev/null +++ b/modules/dnn/src/torch/THFile.cpp @@ -0,0 +1,161 @@ +#include "THFile.h" +#include "THFilePrivate.h" + +extern "C" +{ + +#define IMPLEMENT_THFILE_RW(TYPEC, TYPE) \ + long THFile_read##TYPEC##Raw(THFile *self, TYPE *data, long n) \ + { \ + return (*self->vtable->read##TYPEC)(self, data, n); \ + } \ + \ + long THFile_write##TYPEC##Raw(THFile *self, TYPE *data, long n) \ + { \ + return (*self->vtable->write##TYPEC)(self, data, n); \ + } + +IMPLEMENT_THFILE_RW(Byte, unsigned char) +IMPLEMENT_THFILE_RW(Char, char) +IMPLEMENT_THFILE_RW(Short, short) +IMPLEMENT_THFILE_RW(Int, int) +IMPLEMENT_THFILE_RW(Long, long) +IMPLEMENT_THFILE_RW(Float, float) +IMPLEMENT_THFILE_RW(Double, double) + +long THFile_readStringRaw(THFile *self, const char *format, char **str_) +{ + return self->vtable->readString(self, format, str_); +} + +long THFile_writeStringRaw(THFile *self, const char *str, long size) +{ + return self->vtable->writeString(self, str, size); +} + +void THFile_synchronize(THFile *self) +{ + self->vtable->synchronize(self); +} + +void THFile_seek(THFile *self, long position) +{ + self->vtable->seek(self, position); +} + +void THFile_seekEnd(THFile *self) +{ + self->vtable->seekEnd(self); +} + +long THFile_position(THFile *self) +{ + return self->vtable->position(self); +} + +void THFile_close(THFile *self) +{ + self->vtable->close(self); +} + +void THFile_free(THFile *self) +{ + self->vtable->free(self); +} + +int THFile_isOpened(THFile *self) +{ + return self->vtable->isOpened(self); +} + +#define IMPLEMENT_THFILE_FLAGS(FLAG) \ + int THFile_##FLAG(THFile *self) \ + { \ + return self->FLAG; \ + } + +IMPLEMENT_THFILE_FLAGS(isQuiet) +IMPLEMENT_THFILE_FLAGS(isReadable) +IMPLEMENT_THFILE_FLAGS(isWritable) +IMPLEMENT_THFILE_FLAGS(isBinary) +IMPLEMENT_THFILE_FLAGS(isAutoSpacing) +IMPLEMENT_THFILE_FLAGS(hasError) + +void THFile_binary(THFile *self) +{ + self->isBinary = 1; +} + +void THFile_ascii(THFile *self) +{ + self->isBinary = 0; +} + +void THFile_autoSpacing(THFile *self) +{ + self->isAutoSpacing = 1; +} + +void THFile_noAutoSpacing(THFile *self) +{ + self->isAutoSpacing = 0; +} + +void THFile_quiet(THFile *self) +{ + self->isQuiet = 1; +} + +void THFile_pedantic(THFile *self) +{ + self->isQuiet = 0; +} + +void THFile_clearError(THFile *self) +{ + self->hasError = 0; +} + +#define IMPLEMENT_THFILE_SCALAR(TYPEC, TYPE) \ + TYPE THFile_read##TYPEC##Scalar(THFile *self) \ + { \ + TYPE scalar; \ + THFile_read##TYPEC##Raw(self, &scalar, 1); \ + return scalar; \ + } \ + \ + void THFile_write##TYPEC##Scalar(THFile *self, TYPE scalar) \ + { \ + THFile_write##TYPEC##Raw(self, &scalar, 1); \ + } + +IMPLEMENT_THFILE_SCALAR(Byte, unsigned char) +IMPLEMENT_THFILE_SCALAR(Char, char) +IMPLEMENT_THFILE_SCALAR(Short, short) +IMPLEMENT_THFILE_SCALAR(Int, int) +IMPLEMENT_THFILE_SCALAR(Long, long) +IMPLEMENT_THFILE_SCALAR(Float, float) +IMPLEMENT_THFILE_SCALAR(Double, double) + +/* +#define IMPLEMENT_THFILE_STORAGE(TYPEC, TYPE) \ + long THFile_read##TYPEC(THFile *self, TH##TYPEC##Storage *storage) \ + { \ + return THFile_read##TYPEC##Raw(self, storage->data, storage->size); \ + } \ + \ + long THFile_write##TYPEC(THFile *self, TH##TYPEC##Storage *storage) \ + { \ + return THFile_write##TYPEC##Raw(self, storage->data, storage->size); \ + } + +IMPLEMENT_THFILE_STORAGE(Byte, unsigned char) +IMPLEMENT_THFILE_STORAGE(Char, char) +IMPLEMENT_THFILE_STORAGE(Short, short) +IMPLEMENT_THFILE_STORAGE(Int, int) +IMPLEMENT_THFILE_STORAGE(Long, long) +IMPLEMENT_THFILE_STORAGE(Float, float) +IMPLEMENT_THFILE_STORAGE(Double, double) +*/ + +} \ No newline at end of file diff --git a/modules/dnn/src/torch/THFile.h b/modules/dnn/src/torch/THFile.h new file mode 100644 index 00000000000..3fac5cc1a24 --- /dev/null +++ b/modules/dnn/src/torch/THFile.h @@ -0,0 +1,87 @@ +#ifndef TH_FILE_INC +#define TH_FILE_INC + +//#include "THStorage.h" +#include "THGeneral.h" + +typedef struct THFile__ THFile; + +TH_API int THFile_isOpened(THFile *self); +TH_API int THFile_isQuiet(THFile *self); +TH_API int THFile_isReadable(THFile *self); +TH_API int THFile_isWritable(THFile *self); +TH_API int THFile_isBinary(THFile *self); +TH_API int THFile_isAutoSpacing(THFile *self); +TH_API int THFile_hasError(THFile *self); + +TH_API void THFile_binary(THFile *self); +TH_API void THFile_ascii(THFile *self); +TH_API void THFile_autoSpacing(THFile *self); +TH_API void THFile_noAutoSpacing(THFile *self); +TH_API void THFile_quiet(THFile *self); +TH_API void THFile_pedantic(THFile *self); +TH_API void THFile_clearError(THFile *self); + +/* scalar */ +TH_API unsigned char THFile_readByteScalar(THFile *self); +TH_API char THFile_readCharScalar(THFile *self); +TH_API short THFile_readShortScalar(THFile *self); +TH_API int THFile_readIntScalar(THFile *self); +TH_API long THFile_readLongScalar(THFile *self); +TH_API float THFile_readFloatScalar(THFile *self); +TH_API double THFile_readDoubleScalar(THFile *self); + +TH_API void THFile_writeByteScalar(THFile *self, unsigned char scalar); +TH_API void THFile_writeCharScalar(THFile *self, char scalar); +TH_API void THFile_writeShortScalar(THFile *self, short scalar); +TH_API void THFile_writeIntScalar(THFile *self, int scalar); +TH_API void THFile_writeLongScalar(THFile *self, long scalar); +TH_API void THFile_writeFloatScalar(THFile *self, float scalar); +TH_API void THFile_writeDoubleScalar(THFile *self, double scalar); + +/* storage */ +/* +TH_API long THFile_readByte(THFile *self, THByteStorage *storage); +TH_API long THFile_readChar(THFile *self, THCharStorage *storage); +TH_API long THFile_readShort(THFile *self, THShortStorage *storage); +TH_API long THFile_readInt(THFile *self, THIntStorage *storage); +TH_API long THFile_readLong(THFile *self, THLongStorage *storage); +TH_API long THFile_readFloat(THFile *self, THFloatStorage *storage); +TH_API long THFile_readDouble(THFile *self, THDoubleStorage *storage); + +TH_API long THFile_writeByte(THFile *self, THByteStorage *storage); +TH_API long THFile_writeChar(THFile *self, THCharStorage *storage); +TH_API long THFile_writeShort(THFile *self, THShortStorage *storage); +TH_API long THFile_writeInt(THFile *self, THIntStorage *storage); +TH_API long THFile_writeLong(THFile *self, THLongStorage *storage); +TH_API long THFile_writeFloat(THFile *self, THFloatStorage *storage); +TH_API long THFile_writeDouble(THFile *self, THDoubleStorage *storage); +*/ + +/* raw */ +TH_API long THFile_readByteRaw(THFile *self, unsigned char *data, long n); +TH_API long THFile_readCharRaw(THFile *self, char *data, long n); +TH_API long THFile_readShortRaw(THFile *self, short *data, long n); +TH_API long THFile_readIntRaw(THFile *self, int *data, long n); +TH_API long THFile_readLongRaw(THFile *self, long *data, long n); +TH_API long THFile_readFloatRaw(THFile *self, float *data, long n); +TH_API long THFile_readDoubleRaw(THFile *self, double *data, long n); +TH_API long THFile_readStringRaw(THFile *self, const char *format, char **str_); /* you must deallocate str_ */ + +TH_API long THFile_writeByteRaw(THFile *self, unsigned char *data, long n); +TH_API long THFile_writeCharRaw(THFile *self, char *data, long n); +TH_API long THFile_writeShortRaw(THFile *self, short *data, long n); +TH_API long THFile_writeIntRaw(THFile *self, int *data, long n); +TH_API long THFile_writeLongRaw(THFile *self, long *data, long n); +TH_API long THFile_writeFloatRaw(THFile *self, float *data, long n); +TH_API long THFile_writeDoubleRaw(THFile *self, double *data, long n); +TH_API long THFile_writeStringRaw(THFile *self, const char *str, long size); + +TH_API void THFile_synchronize(THFile *self); +TH_API void THFile_seek(THFile *self, long position); +TH_API void THFile_seekEnd(THFile *self); +TH_API long THFile_position(THFile *self); +TH_API void THFile_close(THFile *self); +TH_API void THFile_free(THFile *self); + +#endif diff --git a/modules/dnn/src/torch/THFilePrivate.h b/modules/dnn/src/torch/THFilePrivate.h new file mode 100644 index 00000000000..9097fb9798e --- /dev/null +++ b/modules/dnn/src/torch/THFilePrivate.h @@ -0,0 +1,43 @@ +struct THFile__ +{ + struct THFileVTable *vtable; + + int isQuiet; + int isReadable; + int isWritable; + int isBinary; + int isAutoSpacing; + int hasError; +}; + +/* virtual table definition */ + +struct THFileVTable +{ + int (*isOpened)(THFile *self); + + long (*readByte)(THFile *self, unsigned char *data, long n); + long (*readChar)(THFile *self, char *data, long n); + long (*readShort)(THFile *self, short *data, long n); + long (*readInt)(THFile *self, int *data, long n); + long (*readLong)(THFile *self, long *data, long n); + long (*readFloat)(THFile *self, float *data, long n); + long (*readDouble)(THFile *self, double *data, long n); + long (*readString)(THFile *self, const char *format, char **str_); + + long (*writeByte)(THFile *self, unsigned char *data, long n); + long (*writeChar)(THFile *self, char *data, long n); + long (*writeShort)(THFile *self, short *data, long n); + long (*writeInt)(THFile *self, int *data, long n); + long (*writeLong)(THFile *self, long *data, long n); + long (*writeFloat)(THFile *self, float *data, long n); + long (*writeDouble)(THFile *self, double *data, long n); + long (*writeString)(THFile *self, const char *str, long size); + + void (*synchronize)(THFile *self); + void (*seek)(THFile *self, long position); + void (*seekEnd)(THFile *self); + long (*position)(THFile *self); + void (*close)(THFile *self); + void (*free)(THFile *self); +}; diff --git a/modules/dnn/src/torch/THGeneral.cpp b/modules/dnn/src/torch/THGeneral.cpp new file mode 100644 index 00000000000..792c5516a38 --- /dev/null +++ b/modules/dnn/src/torch/THGeneral.cpp @@ -0,0 +1,254 @@ +#include "THGeneral.h" + +extern "C" +{ + +#ifndef TH_HAVE_THREAD +#define __thread +#endif + +#if defined(TH_DISABLE_HEAP_TRACKING) +#elif (defined(__unix) || defined(_WIN32)) +#include +#elif defined(__APPLE__) +#include +#endif + +/* Torch Error Handling */ +static void defaultTorchErrorHandlerFunction(const char *msg, void *data) +{ + printf("$ Error: %s\n", msg); + exit(-1); +} + +static __thread void (*torchErrorHandlerFunction)(const char *msg, void *data) = defaultTorchErrorHandlerFunction; +static __thread void *torchErrorHandlerData; + +void _THError(const char *file, const int line, const char *fmt, ...) +{ + char msg[2048]; + va_list args; + + /* vasprintf not standard */ + /* vsnprintf: how to handle if does not exists? */ + va_start(args, fmt); + int n = vsnprintf(msg, 2048, fmt, args); + va_end(args); + + if(n < 2048) { + snprintf(msg + n, 2048 - n, " at %s:%d", file, line); + } + + (*torchErrorHandlerFunction)(msg, torchErrorHandlerData); +} + +void _THAssertionFailed(const char *file, const int line, const char *exp, const char *fmt, ...) { + char msg[1024]; + va_list args; + va_start(args, fmt); + vsnprintf(msg, 1024, fmt, args); + va_end(args); + _THError(file, line, "Assertion `%s' failed. %s", exp, msg); +} + +void THSetErrorHandler( void (*torchErrorHandlerFunction_)(const char *msg, void *data), void *data ) +{ + if(torchErrorHandlerFunction_) + torchErrorHandlerFunction = torchErrorHandlerFunction_; + else + torchErrorHandlerFunction = defaultTorchErrorHandlerFunction; + torchErrorHandlerData = data; +} + +/* Torch Arg Checking Handling */ +static void defaultTorchArgErrorHandlerFunction(int argNumber, const char *msg, void *data) +{ + if(msg) + printf("$ Invalid argument %d: %s\n", argNumber, msg); + else + printf("$ Invalid argument %d\n", argNumber); + exit(-1); +} + +static __thread void (*torchArgErrorHandlerFunction)(int argNumber, const char *msg, void *data) = defaultTorchArgErrorHandlerFunction; +static __thread void *torchArgErrorHandlerData; + +void _THArgCheck(const char *file, int line, int condition, int argNumber, const char *fmt, ...) +{ + if(!condition) { + char msg[2048]; + va_list args; + + /* vasprintf not standard */ + /* vsnprintf: how to handle if does not exists? */ + va_start(args, fmt); + int n = vsnprintf(msg, 2048, fmt, args); + va_end(args); + + if(n < 2048) { + snprintf(msg + n, 2048 - n, " at %s:%d", file, line); + } + + (*torchArgErrorHandlerFunction)(argNumber, msg, torchArgErrorHandlerData); + } +} + +void THSetArgErrorHandler( void (*torchArgErrorHandlerFunction_)(int argNumber, const char *msg, void *data), void *data ) +{ + if(torchArgErrorHandlerFunction_) + torchArgErrorHandlerFunction = torchArgErrorHandlerFunction_; + else + torchArgErrorHandlerFunction = defaultTorchArgErrorHandlerFunction; + torchArgErrorHandlerData = data; +} + +static __thread void (*torchGCFunction)(void *data) = NULL; +static __thread void *torchGCData; +static __thread long torchHeapSize = 0; +static __thread long torchHeapSizeSoftMax = 300000000; // 300MB, adjusted upward dynamically + +/* Optional hook for integrating with a garbage-collected frontend. + * + * If torch is running with a garbage-collected frontend (e.g. Lua), + * the GC isn't aware of TH-allocated memory so may not know when it + * needs to run. These hooks trigger the GC to run in two cases: + * + * (1) When a memory allocation (malloc, realloc, ...) fails + * (2) When the total TH-allocated memory hits a dynamically-adjusted + * soft maximum. + */ +void THSetGCHandler( void (*torchGCFunction_)(void *data), void *data ) +{ + torchGCFunction = torchGCFunction_; + torchGCData = data; +} + +static long getAllocSize(void *ptr) { +#if defined(TH_DISABLE_HEAP_TRACKING) + return 0; +#elif defined(__unix) + return malloc_usable_size(ptr); +#elif defined(__APPLE__) + return malloc_size(ptr); +#elif defined(_WIN32) + return _msize(ptr); +#else + return 0; +#endif +} + +/* (1) if the torch-allocated heap size exceeds the soft max, run GC + * (2) if post-GC heap size exceeds 80% of the soft max, increase the + * soft max by 40% + */ +static void maybeTriggerGC() { + if(torchGCFunction && torchHeapSize > torchHeapSizeSoftMax) { + torchGCFunction(torchGCData); + if(torchHeapSize > torchHeapSizeSoftMax * 0.8) { + torchHeapSizeSoftMax = torchHeapSizeSoftMax * 1.4; + } + } +} + +// hooks into the TH heap tracking +void THHeapUpdate(long size) { + torchHeapSize += size; + if (size > 0) + maybeTriggerGC(); +} + +static void* THAllocInternal(long size) +{ + void *ptr; + + if (size > 5120) + { +#if (defined(__unix) || defined(__APPLE__)) && (!defined(DISABLE_POSIX_MEMALIGN)) + if (posix_memalign(&ptr, 64, size) != 0) + ptr = NULL; +/* +#elif defined(_WIN32) + ptr = _aligned_malloc(size, 64); +*/ +#else + ptr = malloc(size); +#endif + } + else + { + ptr = malloc(size); + } + + THHeapUpdate(getAllocSize(ptr)); + return ptr; +} + +void* THAlloc(long size) +{ + void *ptr; + + if(size < 0) + THError("$ Torch: invalid memory size -- maybe an overflow?"); + + if(size == 0) + return NULL; + + ptr = THAllocInternal(size); + + if(!ptr && torchGCFunction) { + torchGCFunction(torchGCData); + ptr = THAllocInternal(size); + } + + if(!ptr) + THError("$ Torch: not enough memory: you tried to allocate %dGB. Buy new RAM!", size/1073741824); + + return ptr; +} + +void* THRealloc(void *ptr, long size) +{ + if(!ptr) + return(THAlloc(size)); + + if(size == 0) + { + THFree(ptr); + return NULL; + } + + if(size < 0) + THError("$ Torch: invalid memory size -- maybe an overflow?"); + + THHeapUpdate(-getAllocSize(ptr)); + void *newptr = realloc(ptr, size); + + if(!newptr && torchGCFunction) { + torchGCFunction(torchGCData); + newptr = realloc(ptr, size); + } + THHeapUpdate(getAllocSize(newptr ? newptr : ptr)); + + if(!newptr) + THError("$ Torch: not enough memory: you tried to reallocate %dGB. Buy new RAM!", size/1073741824); + + return newptr; +} + +void THFree(void *ptr) +{ + THHeapUpdate(-getAllocSize(ptr)); + free(ptr); +} + +double THLog1p(const double x) +{ +#if (defined(_MSC_VER) || defined(__MINGW32__)) + volatile double y = 1 + x; + return log(y) - ((y-1)-x)/y ; /* cancels errors with IEEE arithmetic */ +#else + return log1p(x); +#endif +} + +} diff --git a/modules/dnn/src/torch/THGeneral.h b/modules/dnn/src/torch/THGeneral.h new file mode 100644 index 00000000000..8d33ede3821 --- /dev/null +++ b/modules/dnn/src/torch/THGeneral.h @@ -0,0 +1,89 @@ +#ifndef TH_GENERAL_INC +#define TH_GENERAL_INC + +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef __cplusplus +# define TH_EXTERNC extern "C" +#else +# define TH_EXTERNC extern +#endif + +#define TH_API TH_EXTERNC + +#define THInf DBL_MAX + +//#define TH_INLINE @TH_INLINE@ + +#ifndef __cplusplus +//#define inline @TH_INLINE@ +#endif + +#ifndef M_PI +# define M_PI 3.14159265358979323846 +#endif + +TH_API double THLog1p(const double x); +TH_API void _THError(const char *file, const int line, const char *fmt, ...); +TH_API void _THAssertionFailed(const char *file, const int line, const char *exp, const char *fmt, ...); +TH_API void THSetErrorHandler( void (*torchErrorHandlerFunction)(const char *msg, void *data), void *data ); +TH_API void _THArgCheck(const char *file, int line, int condition, int argNumber, const char *fmt, ...); +TH_API void THSetArgErrorHandler( void (*torchArgErrorHandlerFunction)(int argNumber, const char *msg, void *data), void *data ); +TH_API void* THAlloc(long size); +TH_API void* THRealloc(void *ptr, long size); +TH_API void THFree(void *ptr); +TH_API void THSetGCHandler( void (*torchGCHandlerFunction)(void *data), void *data ); +// this hook should only be called by custom allocator functions +TH_API void THHeapUpdate(long size); + +#define THError(...) _THError(__FILE__, __LINE__, __VA_ARGS__) +#define THArgCheck(...) _THArgCheck(__FILE__, __LINE__, __VA_ARGS__) +#define THAssert(exp) \ +do { \ + if (!(exp)) { \ + _THAssertionFailed(__FILE__, __LINE__, #exp, ""); \ + } \ +} while(0) +#define THAssertMsg(exp, ...) \ +do { \ + if (!(exp)) { \ + _THAssertionFailed(__FILE__, __LINE__, #exp, __VA_ARGS__); \ + } \ +} while(0) + +#define TH_CONCAT_STRING_2(x,y) TH_CONCAT_STRING_2_EXPAND(x,y) +#define TH_CONCAT_STRING_2_EXPAND(x,y) #x #y + +#define TH_CONCAT_STRING_3(x,y,z) TH_CONCAT_STRING_3_EXPAND(x,y,z) +#define TH_CONCAT_STRING_3_EXPAND(x,y,z) #x #y #z + +#define TH_CONCAT_STRING_4(x,y,z,w) TH_CONCAT_STRING_4_EXPAND(x,y,z,w) +#define TH_CONCAT_STRING_4_EXPAND(x,y,z,w) #x #y #z #w + +#define TH_CONCAT_2(x,y) TH_CONCAT_2_EXPAND(x,y) +#define TH_CONCAT_2_EXPAND(x,y) x ## y + +#define TH_CONCAT_3(x,y,z) TH_CONCAT_3_EXPAND(x,y,z) +#define TH_CONCAT_3_EXPAND(x,y,z) x ## y ## z + +#define TH_CONCAT_4_EXPAND(x,y,z,w) x ## y ## z ## w +#define TH_CONCAT_4(x,y,z,w) TH_CONCAT_4_EXPAND(x,y,z,w) + +#define THMin(X, Y) ((X) < (Y) ? (X) : (Y)) +#define THMax(X, Y) ((X) > (Y) ? (X) : (Y)) + +#if (defined(_MSC_VER) || defined(__MINGW32__)) +# define log1p(x) THLog1p(x) +#define snprintf _snprintf +#define popen _popen +#define pclose _pclose +#endif + +#endif diff --git a/modules/dnn/src/torch/torch_importer.cpp b/modules/dnn/src/torch/torch_importer.cpp new file mode 100644 index 00000000000..7ea55c82591 --- /dev/null +++ b/modules/dnn/src/torch/torch_importer.cpp @@ -0,0 +1,317 @@ +#include "../precomp.hpp" +#include +#include +#include +#include +#include + +namespace cv { +namespace dnn { + +#if ENABLE_TORCH_IMPORTER || 1 +#include "THDiskFile.h" + +enum LuaType +{ + TYPE_NIL = 0, + TYPE_NUMBER = 1, + TYPE_STRING = 2, + TYPE_TABLE = 3, + TYPE_TORCH = 4, + TYPE_BOOLEAN = 5, + TYPE_FUNCTION = 6, + TYPE_RECUR_FUNCTION = 8, + LEGACY_TYPE_RECUR_FUNCTION = 7 +}; + +template +static String toString(const T &v) +{ + std::ostringstream ss; + ss << v; + return ss.str(); +} + +static inline bool startsWith(const String &str, const char *substr) +{ + return str.find(substr) == 0; +} + +static inline bool endsWith(const String &str, const char *substr) +{ + return str.rfind(substr) == str.length() - strlen(substr); +} + + +struct TorchImporter : public ::cv::dnn::Importer +{ + THFile *file; + std::set readedIndexes; + std::map storages; + + TorchImporter(String filename, bool isBinary) + { + file = THDiskFile_new(filename.c_str(), "r", 0); + CV_Assert(file && THFile_isOpened(file)); + + if (isBinary) + THFile_binary(file); + else + THFile_ascii(file); + } + + /* Simple readers */ + + inline int readInt() + { + return THFile_readIntScalar(file); + } + + inline long readLong() + { + return THFile_readLongScalar(file); + } + + inline bool readBool() + { + return readInt(); + } + + inline double readDouble() + { + return THFile_readDoubleScalar(file); + } + + inline String readString() + { + int size = THFile_readIntScalar(file); + String str(size, '\0'); + THFile_readCharRaw(file, const_cast(str.c_str()), size); + return str; + } + + inline String readTorchClassName() + { + String version = readString(); + return startsWith(version, "V ") ? readString() : version; + } + + inline void readFunction() + { + readString(); + readObject(true); + } + + void readTable() + { + std::cout << "Skipping table\n"; + + int index = readInt(); + CV_Assert(readedIndexes.count(index) == 0); + readedIndexes.insert(index); + + int size = readInt(); + for (int i = 0; i < size; i++) + { + readObject(true); //key + readObject(true); //value + } + } + + /* Special readers */ + + static inline int parseTorchType(const String &str, const char *suffix, const char *prefix = "torch.") + { + if (startsWith(str, prefix) && endsWith(str, suffix)) + { + String typeStr = str.substr(strlen(prefix), str.length() - strlen(prefix) - strlen(suffix)); + + if (typeStr == "Double") + return CV_64F; + else if (typeStr == "Float") + return CV_32F; + else if (typeStr == "Byte") + return CV_8U; + else if (typeStr == "Char") + return CV_8S; + else if (typeStr == "Short") + return CV_16S; + else if (typeStr == "Int") + return CV_32S; + else + CV_Error(Error::StsNotImplemented, "Unknown type \"" + typeStr + "\" of torch class \"" + str + "\""); + } + + return -1; + } + + static int parseTensorType(const String &className) + { + return parseTorchType(className, "Tensor"); + } + + static int parseStorageType(const String &className) + { + return parseTorchType(className, "Storage"); + } + + void readTorchStorage(int index, int type = -1) + { + long size = readLong(); + Mat storageMat(1, size, type); + + THFile_readByteRaw(file, storageMat.data, size * CV_ELEM_SIZE(type)); + + storages.insert(std::make_pair(index, storageMat)); + readedIndexes.insert(index); + } + + Blob readTorchTensor(int typeTensor, bool skip = false) + { + int ndims = readInt(); + + AutoBuffer sizes(ndims); + AutoBuffer steps(ndims); + THFile_readLongRaw(file, sizes, ndims); + THFile_readLongRaw(file, sizes, ndims); + + long offset = readLong() - 1; + + //read Storage + int typeidx = readInt(); + std::cout << "stograge typeidx of tensor: " << typeidx << "\n"; + CV_Assert(typeidx == TYPE_TORCH || (typeidx == TYPE_NIL && ndims == 0)); + + if (typeidx == TYPE_NIL) + return Blob(); + + int index = readInt(); + if (readedIndexes.count(index) == 0) + { + int typeStorage = parseStorageType(readTorchClassName()); + CV_Assert(typeStorage >= 0 && typeTensor == typeStorage); + readTorchStorage(typeStorage, index); + } + + //allocate Blob + AutoBuffer isizes(ndims); + AutoBuffer ssteps(ndims); + + size_t stepExpected = 1; + for (int i = ndims - 1; i >= 0; i--) + { + isizes[i] = (int)sizes[i]; + ssteps[i] = (size_t)steps[i] * CV_ELEM_SIZE(typeTensor); + + stepExpected *= sizes[i]; + } + + if (skip) + return Blob(); + + Mat srcMat(ndims, (int*)isizes, typeTensor , storages[index].ptr(), (size_t*)ssteps); + int dstType = (typeTensor == CV_64F) ? CV_64F : CV_32F; + + Blob blob; + blob.create(BlobShape(ndims, isizes), dstType); + srcMat.convertTo(blob.getMatRef(), dstType); + + return blob; + } + + void readTorchObject(int index, bool skip = false) + { + String className = readTorchClassName(); + std::cout << "Class: " << className << std::endl; + + int type; + if ( (type = parseTensorType(className)) >= 0 ) //is Tensor + { + readTorchTensor(type); + return; + } + else if ( (type = parseStorageType(className)) >= 0 ) //is Storage + { + readTorchStorage(index, type); + } + else if (className == "nn.Sequential") + { + readObject(); + } + else if (className == "nn.Concat") + { + readObject(); + } + else if (className == "nn.SpatialConvolution") + { + readObject(); + } + else if (className == "nn.ReLU") + { + readObject(); + } + else + { + CV_Error(Error::StsNotImplemented, "Unsupported Torch class \"" + className +"\""); + } + } + + void readObject(bool skip = false) + { + int typeidx = readInt(); + std::cout << "typeidx: " << typeidx << "\n"; + + if (typeidx == TYPE_TORCH) + { + int index = readInt(); + + if (readedIndexes.count(index) == 0) + { + readTorchObject(index, skip); + readedIndexes.insert(index); + } + else + { + //CV_Error(Error::StsNotImplemented, ""); + //TBD + } + } + else if (typeidx == TYPE_NIL) + return; + else if (typeidx == TYPE_NUMBER) + readDouble(); + else if (typeidx == TYPE_BOOLEAN) + readBool(); + else if (typeidx == TYPE_STRING) + readString(); + else if (typeidx == TYPE_TABLE) + readTable(); + else + CV_Error(Error::StsNotImplemented, "Unsupported Lua type"); + } + + void populateNet(Net net) + { + THFile_seek(file, 0); + readedIndexes.clear(); + + readObject(); + } +}; + +CV_EXPORTS Ptr createTorchImporter(const String &filename, bool isBinary) +{ + return Ptr(new TorchImporter(filename, isBinary)); +} + +#else //ENABLE_TORCH_IMPORTER + +CV_EXPORTS Ptr createTorchImporter(const String&, bool) +{ + CV_Error(Error::StsNotImplemented, "Module was build without Torch importer"); + return Ptr(); +} + +#endif //ENABLE_TORCH_IMPORTER +} +} diff --git a/modules/dnn/test/test_torch_importer.cpp b/modules/dnn/test/test_torch_importer.cpp new file mode 100644 index 00000000000..92b7d88d0be --- /dev/null +++ b/modules/dnn/test/test_torch_importer.cpp @@ -0,0 +1,35 @@ +#if 1 || defined(ENABLE_TORCH_IMPORTER) && ENABLE_TORCH_IMPORTER +#include "test_precomp.hpp" + +namespace cvtest +{ + +using namespace std; +using namespace testing; +using namespace cv; +using namespace cv::dnn; + +static std::string getOpenCVExtraDir() +{ + return cvtest::TS::ptr()->get_data_path(); +} + +template +static std::string getTestFile(TStr filename) +{ + return (getOpenCVExtraDir() + "/dnn/") + filename; +} + +TEST(Torch_Importer, simple_read) +{ + Net net; + Ptr importer; + + ASSERT_NO_THROW( importer = createTorchImporter("/home/vitaliy/th/conv1.txt", false) ); + ASSERT_TRUE( importer != NULL ); + importer->populateNet(net); + //ASSERT_NO_THROW( importer->populateNet(net) ); +} + +} +#endif