
Commit 2454a40 — committed Jun 2, 2020

start of dotnet#146

1 parent 4aa0648

14 files changed: +197 −192 lines
 

src/Native/LibTorchSharp/THSNN.cpp (+95 −113)

@@ -26,18 +26,6 @@ void THSNN_Module_register_module(const NNModule module, const char* name, const
     );
 }
 
-NNModule THSNN_Module_load(const char* location, const char* name)
-{
-    CATCH_RETURN_NNModule(
-        auto module = new torch::nn::Module();
-        auto input = torch::serialize::InputArchive();
-
-        input.load_from(location);
-        module->load(input);
-        res = new std::shared_ptr<torch::nn::Module>(module);
-    );
-}
-
 int THSNN_Module_has_parameter(const NNModule module, const char* name)
 {
     CATCH_RETURN(int, 0, (*module)->named_parameters().contains(name));
@@ -135,212 +123,206 @@ class CustomModule : public torch::nn::Module
 
 };
 
-NNModule THSNN_custom_module(const char* name,
+NNModule THSNN_CustomModule_ctor(const char* name,
     const char** names,
     at::Tensor** parameters,
     const bool* require_grad,
     const int length,
-    Tensor(*forward)(Tensor),
-    NNAnyModule *outAsAnyModule)
+    Tensor(*forward)(Tensor))
 {
     CATCH_RETURN_NNModule(
         auto mod = new CustomModule(name, names, parameters, require_grad, length, forward);
-
-        // Keep a boxed version of the module in case we add it to a Sequential later (the C++ templating means
-        // a Module can only be boxed to AnyModule at the point its static type is known).
-        if (outAsAnyModule != NULL)
-        {
-            auto modShared = new std::shared_ptr<CustomModule>(mod);
-            auto wrapped = std::make_shared<torch::nn::AnyModule>(torch::nn::ModuleHolder<CustomModule>(*modShared));
-            *outAsAnyModule = new std::shared_ptr<torch::nn::AnyModule>(wrapped);
-        }
         res = new std::shared_ptr<torch::nn::Module>((torch::nn::Module*)mod);
     );
 }
 
-NNModule THSNN_ReLU_ctor(bool inplace, NNAnyModule* outAsAnyModule)
+NNAnyModule THSNN_CustomModule_wrap(const NNModule module)
+{
+    CATCH_RETURN_NNAnyModule(
+        auto wrapped = std::make_shared<torch::nn::AnyModule>(torch::nn::ModuleHolder<CustomModule>(*((*module)->as<CustomModule>())));
+        res = new std::shared_ptr<torch::nn::AnyModule>(wrapped);
+    );
+}
+
+NNModule THSNN_ReLU_ctor(bool inplace)
 {
     CATCH_RETURN_NNModule(
         auto opts = torch::nn::ReLUOptions(inplace);
         auto mod = std::make_shared<torch::nn::ReLUImpl>(opts);
-
-        // Keep a boxed version of the module in case we add it to a Sequential later (the C++ templating means
-        // a Module can only be boxed to AnyModule at the point its static type is known).
-        if (outAsAnyModule != NULL)
-        {
-            auto wrapped = std::make_shared<torch::nn::AnyModule>(torch::nn::ModuleHolder<torch::nn::ReLUImpl>(*mod));
-            *outAsAnyModule = new std::shared_ptr<torch::nn::AnyModule>(wrapped);
-        }
-
         res = new std::shared_ptr<torch::nn::Module>(mod);
     );
 }
 
+NNAnyModule THSNN_ReLU_wrap(const NNModule module)
+{
+    CATCH_RETURN_NNAnyModule(
+        auto wrapped = std::make_shared<torch::nn::AnyModule>(torch::nn::ModuleHolder<torch::nn::ReLUImpl>(*((*module)->as<torch::nn::ReLU>())));
+        res = new std::shared_ptr<torch::nn::AnyModule>(wrapped);
+    );
+}
+
 Tensor THSNN_ReLU_forward(const NNModule module, const Tensor tensor)
 {
     CATCH_TENSOR((*module)->as<torch::nn::ReLU>()->forward(*tensor));
 }
 
-NNModule THSNN_Dropout_ctor(double probability, NNAnyModule* outAsAnyModule)
+NNModule THSNN_Dropout_ctor(double probability)
 {
     CATCH_RETURN_NNModule(
         auto opts = torch::nn::DropoutOptions(probability);
         auto mod = std::make_shared<torch::nn::DropoutImpl>(opts);
-
-        // Keep a boxed version of the module in case we add it to a Sequential later (the C++ templating means
-        // a Module can only be boxed to AnyModule at the point its static type is known).
-        if (outAsAnyModule != NULL)
-        {
-            auto wrapped = std::make_shared<torch::nn::AnyModule>(torch::nn::ModuleHolder<torch::nn::DropoutImpl>(*mod));
-            *outAsAnyModule = new std::shared_ptr<torch::nn::AnyModule>(wrapped);
-        }
-
         res = new std::shared_ptr<torch::nn::Module>(mod);
     );
 }
 
+NNAnyModule THSNN_Dropout_wrap(const NNModule module)
+{
+    CATCH_RETURN_NNAnyModule(
+        auto wrapped = std::make_shared<torch::nn::AnyModule>(torch::nn::ModuleHolder<torch::nn::DropoutImpl>(*((*module)->as<torch::nn::Dropout>())));
+        res = new std::shared_ptr<torch::nn::AnyModule>(wrapped);
+    );
+}
+
 Tensor THSNN_Dropout_forward(const NNModule module, const Tensor tensor)
 {
     CATCH_TENSOR((*module)->as<torch::nn::Dropout>()->forward(*tensor));
 }
 
-NNModule THSNN_FeatureAlphaDropout_ctor(double probability, NNAnyModule* outAsAnyModule)
+NNModule THSNN_FeatureAlphaDropout_ctor(double probability)
 {
     CATCH_RETURN_NNModule(
         auto opts = torch::nn::FeatureAlphaDropoutOptions(probability);
         auto mod = std::make_shared<torch::nn::FeatureAlphaDropoutImpl>(opts);
-
-        // Keep a boxed version of the module in case we add it to a Sequential later (the C++ templating means
-        // a Module can only be boxed to AnyModule at the point its static type is known).
-        if (outAsAnyModule != NULL)
-        {
-            auto wrapped = std::make_shared<torch::nn::AnyModule>(torch::nn::ModuleHolder<torch::nn::FeatureAlphaDropoutImpl>(*mod));
-            *outAsAnyModule = new std::shared_ptr<torch::nn::AnyModule>(wrapped);
-        }
         res = new std::shared_ptr<torch::nn::Module>(mod);
     );
 }
 
+NNAnyModule THSNN_FeatureAlphaDropout_wrap(const NNModule module)
+{
+    CATCH_RETURN_NNAnyModule(
+        auto wrapped = std::make_shared<torch::nn::AnyModule>(torch::nn::ModuleHolder<torch::nn::FeatureAlphaDropoutImpl>(*((*module)->as<torch::nn::FeatureAlphaDropout>())));
+        res = new std::shared_ptr<torch::nn::AnyModule>(wrapped);
+    );
+}
+
 Tensor THSNN_FeatureAlphaDropout_forward(const NNModule module, const Tensor tensor)
 {
     CATCH_TENSOR((*module)->as<torch::nn::FeatureAlphaDropout>()->forward(*tensor));
 }
 
-NNModule THSNN_LogSoftMax_ctor(int64_t dim, NNAnyModule* outAsAnyModule)
+NNModule THSNN_LogSoftmax_ctor(const int64_t dim)
 {
     CATCH_RETURN_NNModule(
         auto opts = torch::nn::LogSoftmaxOptions(dim);
         auto mod = std::make_shared<torch::nn::LogSoftmaxImpl>(opts);
-
-        // Keep a boxed version of the module in case we add it to a Sequential later (the C++ templating means
-        // a Module can only be boxed to AnyModule at the point its static type is known).
-        if (outAsAnyModule != NULL)
-        {
-            auto wrapped = std::make_shared<torch::nn::AnyModule>(torch::nn::ModuleHolder<torch::nn::LogSoftmaxImpl>(*mod));
-            *outAsAnyModule = new std::shared_ptr<torch::nn::AnyModule>(wrapped);
-        }
         res = new std::shared_ptr<torch::nn::Module>(mod);
     );
 }
 
-Tensor THSNN_LogSoftMax_forward(const NNModule module, const Tensor tensor)
+NNAnyModule THSNN_LogSoftmax_wrap(const NNModule module)
+{
+    CATCH_RETURN_NNAnyModule(
+        auto wrapped = std::make_shared<torch::nn::AnyModule>(torch::nn::ModuleHolder<torch::nn::LogSoftmaxImpl>(*((*module)->as<torch::nn::LogSoftmax>())));
+        res = new std::shared_ptr<torch::nn::AnyModule>(wrapped);
+    );
+}
+
+Tensor THSNN_LogSoftmax_forward(const NNModule module, const Tensor tensor)
 {
     CATCH_TENSOR((*module)->as<torch::nn::LogSoftmax>()->forward(*tensor));
 }
 
-NNModule THSNN_AvgPool2d_ctor(const int64_t* kernelSize, const int kernelSizeLength, const int64_t* stride, const int strideLength,
-    NNAnyModule* outAsAnyModule)
+NNModule THSNN_AvgPool2d_ctor(const int64_t* kernelSize, const int kernelSizeLength, const int64_t* stride, const int strideLength)
 {
     CATCH_RETURN_NNModule(
         auto opts = torch::nn::AvgPool2dOptions(at::ArrayRef<int64_t>(kernelSize, kernelSizeLength));
         if (stride)
             opts = opts.stride(at::ArrayRef<int64_t>(stride, strideLength));
         auto mod = std::make_shared<torch::nn::AvgPool2dImpl>(opts);
-
-        // Keep a boxed version of the module in case we add it to a Sequential later (the C++ templating means
-        // a Module can only be boxed to AnyModule at the point its static type is known).
-        if (outAsAnyModule != NULL)
-        {
-            auto wrapped = std::make_shared<torch::nn::AnyModule>(torch::nn::ModuleHolder<torch::nn::AvgPool2dImpl>(*mod));
-            *outAsAnyModule = new std::shared_ptr<torch::nn::AnyModule>(wrapped);
-        }
         res = new std::shared_ptr<torch::nn::Module>(mod);
     );
 }
 
+NNAnyModule THSNN_AvgPool2d_wrap(const NNModule module)
+{
+    CATCH_RETURN_NNAnyModule(
+        auto wrapped = std::make_shared<torch::nn::AnyModule>(torch::nn::ModuleHolder<torch::nn::AvgPool2dImpl>(*((*module)->as<torch::nn::AvgPool2d>())));
+        res = new std::shared_ptr<torch::nn::AnyModule>(wrapped);
+    );
+}
+
 Tensor THSNN_AvgPool2d_forward(const NNModule module, const Tensor tensor)
 {
     CATCH_TENSOR((*module)->as<torch::nn::AvgPool2d>()->forward(*tensor));
 }
 
-NNModule THSNN_AdaptiveAvgPool2d_ctor(const int64_t* kernelSize, const int kernelSizeLength,
-    NNAnyModule* outAsAnyModule)
+NNModule THSNN_AdaptiveAvgPool2d_ctor(const int64_t* kernelSize, const int kernelSizeLength)
 {
     CATCH_RETURN_NNModule(
         auto opts = torch::nn::AdaptiveAvgPool2dOptions(at::ArrayRef<int64_t>(kernelSize, kernelSizeLength));
         auto mod = std::make_shared<torch::nn::AdaptiveAvgPool2dImpl>(opts);
-
-        // Keep a boxed version of the module in case we add it to a Sequential later (the C++ templating means
-        // a Module can only be boxed to AnyModule at the point its static type is known).
-        if (outAsAnyModule != NULL)
-        {
-            auto wrapped = std::make_shared<torch::nn::AnyModule>(torch::nn::ModuleHolder<torch::nn::AdaptiveAvgPool2dImpl>(*mod));
-            *outAsAnyModule = new std::shared_ptr<torch::nn::AnyModule>(wrapped);
-        }
         res = new std::shared_ptr<torch::nn::Module>(mod);
     );
 }
 
+NNAnyModule THSNN_AdaptiveAvgPool2d_wrap(const NNModule module)
+{
+    CATCH_RETURN_NNAnyModule(
+        auto wrapped = std::make_shared<torch::nn::AnyModule>(torch::nn::ModuleHolder<torch::nn::AdaptiveAvgPool2dImpl>(*((*module)->as<torch::nn::AdaptiveAvgPool2d>())));
        res = new std::shared_ptr<torch::nn::AnyModule>(wrapped);
+    );
+}
+
 Tensor THSNN_AdaptiveAvgPool2d_forward(const NNModule module, const Tensor tensor)
 {
     CATCH_TENSOR((*module)->as<torch::nn::AdaptiveAvgPool2d>()->forward(*tensor));
 }
 
-NNModule THSNN_MaxPool2d_ctor(const int64_t* kernelSize, const int kernelSizeLength, const int64_t* stride, const int strideLength,
-    NNAnyModule* outAsAnyModule)
+NNModule THSNN_MaxPool2d_ctor(const int64_t* kernelSize, const int kernelSizeLength, const int64_t* stride, const int strideLength)
 {
     CATCH_RETURN_NNModule(
         auto opts = torch::nn::MaxPool2dOptions(at::ArrayRef<int64_t>(kernelSize, kernelSizeLength));
         auto mod = std::make_shared<torch::nn::MaxPool2dImpl>(opts);
         if (stride)
             opts = opts.stride(at::ArrayRef<int64_t>(stride, strideLength));
 
-        // Keep a boxed version of the module in case we add it to a Sequential later (the C++ templating means
-        // a Module can only be boxed to AnyModule at the point its static type is known).
-        if (outAsAnyModule != NULL)
-        {
-            auto wrapped = std::make_shared<torch::nn::AnyModule>(torch::nn::ModuleHolder<torch::nn::MaxPool2dImpl>(*mod));
-            *outAsAnyModule = new std::shared_ptr<torch::nn::AnyModule>(wrapped);
-        }
         res = new std::shared_ptr<torch::nn::Module>(mod);
     )
 }
 
+NNAnyModule THSNN_MaxPool2d_wrap(const NNModule module)
+{
+    CATCH_RETURN_NNAnyModule(
+        auto wrapped = std::make_shared<torch::nn::AnyModule>(torch::nn::ModuleHolder<torch::nn::MaxPool2dImpl>(*((*module)->as<torch::nn::MaxPool2d>())));
+        res = new std::shared_ptr<torch::nn::AnyModule>(wrapped);
+    );
+}
+
 Tensor THSNN_MaxPool2d_forward(const NNModule module, const Tensor tensor)
 {
     CATCH_TENSOR((*module)->as<torch::nn::MaxPool2d>()->forward(*tensor));
 }
 
-NNModule THSNN_Linear_ctor(const int64_t input_size, const int64_t output_size, const bool bias,
-    NNAnyModule* outAsAnyModule)
+NNModule THSNN_Linear_ctor(const int64_t input_size, const int64_t output_size, const bool bias)
 {
     CATCH_RETURN_NNModule(
         auto opts = torch::nn::LinearOptions(input_size, output_size);
         opts = opts.bias(bias);
 
         auto mod = std::make_shared<torch::nn::LinearImpl>(opts);
-
-        // Keep a boxed version of the module in case we add it to a Sequential later (the C++ templating means
-        // a Module can only be boxed to AnyModule at the point its static type is known).
-        if (outAsAnyModule != NULL)
-        {
-            auto wrapped = std::make_shared<torch::nn::AnyModule>(torch::nn::ModuleHolder<torch::nn::LinearImpl>(*mod));
-            *outAsAnyModule = new std::shared_ptr<torch::nn::AnyModule>(wrapped);
-        }
         res = new std::shared_ptr<torch::nn::Module>(mod);
     );
 }
 
+NNAnyModule THSNN_Linear_wrap(const NNModule module)
+{
+    //CATCH_RETURN_NNAnyModule(
+    auto p = (*module)->as<torch::nn::Linear>();
+    auto wrapped = std::make_shared<torch::nn::AnyModule>(torch::nn::ModuleHolder<torch::nn::LinearImpl>(*p));
+    return new std::shared_ptr<torch::nn::AnyModule>(wrapped);
+    //);
+}
+
 Tensor THSNN_Linear_forward(const NNModule module, const Tensor tensor)
 {
     CATCH_TENSOR((*module)->as<torch::nn::Linear>()->forward(*tensor));
@@ -371,25 +353,25 @@ void THSNN_Linear_set_weight(const NNModule module, const Tensor weight)
 }
 
 NNModule THSNN_Conv2d_ctor(const int64_t inputChannel, const int64_t outputChannel,
-    const int64_t kernelSize, const int64_t stride, const int64_t padding,
-    NNAnyModule* outAsAnyModule)
+    const int64_t kernelSize, const int64_t stride, const int64_t padding)
 {
     CATCH_RETURN_NNModule(
         auto opts = torch::nn::Conv2dOptions(inputChannel, outputChannel, kernelSize).stride(stride).padding(padding);
 
-        auto mod = std::make_shared<torch::nn::Conv2dImpl>(opts);
-
-        // Keep a boxed version of the module in case we add it to a Sequential later (the C++ templating means
-        // a Module can only be boxed to AnyModule at the point its static type is known).
-        if (outAsAnyModule != NULL)
-        {
-            auto wrapped = std::make_shared<torch::nn::AnyModule>(torch::nn::ModuleHolder<torch::nn::Conv2dImpl>(*mod));
-            *outAsAnyModule = new std::shared_ptr<torch::nn::AnyModule>(wrapped);
-        }
+        auto mod = std::make_shared<torch::nn::Conv2dImpl>(opts);
        res = new std::shared_ptr<torch::nn::Module>(mod);
    );
}
 
+NNAnyModule THSNN_Conv2d_wrap(const NNModule module)
+{
+    //CATCH_RETURN_NNAnyModule(
+    auto p = (*module)->as<torch::nn::Conv2d>();
+    auto wrapped = std::make_shared<torch::nn::AnyModule>(torch::nn::ModuleHolder<torch::nn::Conv2dImpl>(*p));
+    return new std::shared_ptr<torch::nn::AnyModule>(wrapped);
+    //);
+}
+
 Tensor THSNN_Conv2d_forward(const NNModule module, const Tensor tensor)
 {
     CATCH_TENSOR((*module)->as<torch::nn::Conv2d>()->forward(*tensor));
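The change above repeats one pattern for every module type: the old _ctor functions boxed their module into a torch::nn::AnyModule through the outAsAnyModule out-parameter, while the new _wrap functions perform that boxing on demand by downcasting the type-erased handle with as<T>(). The boxing requirement itself comes from libtorch's templates, as the deleted comments noted. A minimal standalone sketch of that constraint (not part of this commit; it assumes only the public libtorch C++ API):

    #include <torch/torch.h>
    #include <iostream>

    int main()
    {
        // AnyModule's constructor is a template over the concrete Impl type,
        // so a module can only be boxed where its static type is known --
        // which is why each THSNN_*_wrap downcasts with as<T>() before boxing.
        torch::nn::Linear lin(torch::nn::LinearOptions(10, 5));
        torch::nn::AnyModule boxed(lin);

        // Sequential stores AnyModule internally; push_back does the same boxing.
        torch::nn::Sequential seq;
        seq->push_back(lin);

        auto y = boxed.forward(torch::randn({2, 10}));
        std::cout << y.sizes() << std::endl; // expect [2, 5]
        return 0;
    }

Splitting construction from boxing lets the caller decide per instance whether a boxed handle is needed at all, instead of paying for it in every constructor.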

src/Native/LibTorchSharp/THSNN.h (+22 −11)

@@ -21,40 +21,51 @@ EXPORT_API(NNModule) THSNN_Module_child(const NNModule module, const int index);
 EXPORT_API(const char*) THSNN_Module_name(const NNModule module);
 EXPORT_API(void) THSNN_Module_zero_grad(const NNModule module);
 EXPORT_API(void) THSNN_Module_save(const NNModule module, const char * location);
-EXPORT_API(NNModule) THSNN_Module_load(const char * location, const char * name);
 EXPORT_API(void) THSNN_Module_register_module(const NNModule module, const char* name, const NNModule submodule);
 EXPORT_API(void) THSNN_Module_dispose(const NNModule module);
 
 EXPORT_API(void) THSNN_AnyModule_dispose(const NNAnyModule module);
-//EXPORT_API(NNModule) THSNN_AnyModule_get(const NNAnyModule module);
 
-EXPORT_API(NNModule) THSNN_custom_module(const char* name, const char** names, at::Tensor** parameters, const bool* require_grad, const int length, Tensor(*forward)(Tensor), NNAnyModule* outAsAnyModule);
-EXPORT_API(NNModule) THSNN_AdaptiveAvgPool2d_ctor(const int64_t* sizes, const int length, NNAnyModule* outAsAnyModule);
+EXPORT_API(NNModule) THSNN_CustomModule_ctor(const char* name, const char** names, at::Tensor** parameters, const bool* require_grad, const int length, Tensor(*forward)(Tensor));
+EXPORT_API(NNAnyModule) THSNN_CustomModule_wrap(const NNModule module);
+EXPORT_API(NNModule) THSNN_AdaptiveAvgPool2d_ctor(const int64_t* sizes, const int length);
+EXPORT_API(NNAnyModule) THSNN_AdaptiveAvgPool2d_wrap(const NNModule module);
 EXPORT_API(Tensor) THSNN_AdaptiveAvgPool2d_forward(const NNModule module, const Tensor tensor);
 
-EXPORT_API(NNModule) THSNN_AvgPool2d_ctor(const int64_t* kernelSize, const int kernelSizeLength, const int64_t* stride, const int strideLength, NNAnyModule* outAsAnyModule);
+EXPORT_API(NNModule) THSNN_AvgPool2d_ctor(const int64_t* kernelSize, const int kernelSizeLength, const int64_t* stride, const int strideLength);
+EXPORT_API(NNAnyModule) THSNN_AvgPool2d_wrap(const NNModule module);
 EXPORT_API(Tensor) THSNN_AvgPool2d_forward(const NNModule module, const Tensor tensor);
 
-EXPORT_API(NNModule) THSNN_Conv2d_ctor(const int64_t inputChannel, const int64_t outputChannel, const int64_t kernelSize, const int64_t stride, const int64_t padding, NNAnyModule* outAsAnyModule);
+EXPORT_API(NNModule) THSNN_Conv2d_ctor(const int64_t inputChannel, const int64_t outputChannel, const int64_t kernelSize, const int64_t stride, const int64_t padding);
+EXPORT_API(NNAnyModule) THSNN_Conv2d_wrap(const NNModule module);
 EXPORT_API(Tensor) THSNN_Conv2d_forward(const NNModule module, const Tensor tensor);
 
-EXPORT_API(NNModule) THSNN_MaxPool2d_ctor(const int64_t* kernelSize, const int kernelSizeLength, const int64_t* stride, const int strideLength, NNAnyModule* outAsAnyModule);
+EXPORT_API(NNModule) THSNN_MaxPool2d_ctor(const int64_t* kernelSize, const int kernelSizeLength, const int64_t* stride, const int strideLength);
+EXPORT_API(NNAnyModule) THSNN_MaxPool2d_wrap(const NNModule module);
 EXPORT_API(Tensor) THSNN_MaxPool2d_forward(const NNModule module, const Tensor tensor);
 
-EXPORT_API(NNModule) THSNN_Dropout_ctor(double probability, NNAnyModule* outAsAnyModule);
+EXPORT_API(NNModule) THSNN_Dropout_ctor(double probability);
+EXPORT_API(NNAnyModule) THSNN_Dropout_wrap(const NNModule module);
 EXPORT_API(Tensor) THSNN_Dropout_forward(const NNModule module, const Tensor tensor);
 
-EXPORT_API(NNModule) THSNN_FeatureAlphaDropout_ctor(double probability, NNAnyModule* outAsAnyModule);
+EXPORT_API(NNModule) THSNN_LogSoftmax_ctor(const int64_t dim);
+EXPORT_API(NNAnyModule) THSNN_LogSoftmax_wrap(const NNModule module);
+EXPORT_API(Tensor) THSNN_LogSoftmax_forward(const NNModule module, const Tensor tensor);
+
+EXPORT_API(NNModule) THSNN_FeatureAlphaDropout_ctor(double probability);
+EXPORT_API(NNAnyModule) THSNN_FeatureAlphaDropout_wrap(const NNModule module);
 EXPORT_API(Tensor) THSNN_FeatureAlphaDropout_forward(const NNModule module, const Tensor tensor);
 
-EXPORT_API(NNModule) THSNN_Linear_ctor(const int64_t input_size, const int64_t output_size, const bool with_bias, NNAnyModule* outAsAnyModule);
+EXPORT_API(NNModule) THSNN_Linear_ctor(const int64_t input_size, const int64_t output_size, const bool with_bias);
+EXPORT_API(NNAnyModule) THSNN_Linear_wrap(const NNModule module);
 EXPORT_API(Tensor) THSNN_Linear_forward(const NNModule module, const Tensor tensor);
 EXPORT_API(Tensor) THSNN_Linear_bias(const NNModule module);
 EXPORT_API(void) THSNN_Linear_set_bias(const NNModule module, const Tensor tensor);
 EXPORT_API(Tensor) THSNN_Linear_weight(const NNModule module);
 EXPORT_API(void) THSNN_Linear_set_weight(const NNModule module, const Tensor tensor);
 
-EXPORT_API(NNModule) THSNN_ReLU_ctor(bool inplace, NNAnyModule* outAsAnyModule);
+EXPORT_API(NNModule) THSNN_ReLU_ctor(bool inplace);
+EXPORT_API(NNAnyModule) THSNN_ReLU_wrap(const NNModule module);
 EXPORT_API(Tensor) THSNN_ReLU_forward(const NNModule module, const Tensor tensor);
 
 EXPORT_API(NNModule) THSNN_Sequential_ctor();
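Taken together, the header now defines a two-call protocol per module type: _ctor creates the native handle, _wrap boxes it separately. A hedged caller-side sketch using only functions declared above (run_linear_once and its input argument are illustrative, not part of the API):

    // Illustrative caller of the THSNN C API declared in this header.
    Tensor run_linear_once(Tensor input)
    {
        NNModule linear = THSNN_Linear_ctor(100, 10, /* with_bias */ true);
        NNAnyModule boxed = THSNN_Linear_wrap(linear); // boxed separately, on demand

        Tensor output = THSNN_Linear_forward(linear, input);

        // The two handles are owned and released independently.
        THSNN_AnyModule_dispose(boxed);
        THSNN_Module_dispose(linear);
        return output;
    }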

src/Native/LibTorchSharp/Utils.h (+1 −0)

@@ -36,6 +36,7 @@ typedef std::shared_ptr<torch::optim::Optimizer> * Optimizer;
 
 #define CATCH_RETURN(ty, dflt, expr) CATCH_RETURN_RES(ty, dflt, res = expr)
 #define CATCH_RETURN_NNModule(stmt) CATCH_RETURN_RES(NNModule, NULL, stmt)
+#define CATCH_RETURN_NNAnyModule(stmt) CATCH_RETURN_RES(NNAnyModule, NULL, stmt)
 #define CATCH_RETURN_Tensor(stmt) CATCH_RETURN_RES(Tensor, NULL, stmt)
 
 // Return undefined tensors as NULL to C#
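The added macro gives NNAnyModule the same error-trapping shape that CATCH_RETURN_NNModule gives NNModule. The shared CATCH_RETURN_RES helper is defined earlier in Utils.h and is not shown in this diff; as a rough sketch of the pattern it implements (an assumption about its shape, not the actual definition):

    // Sketch only -- the real CATCH_RETURN_RES lives elsewhere in Utils.h.
    // Idea: run the statement under try/catch and fall back to the default
    // (NULL) so the managed side can surface errors via Torch.CheckForErrors().
    #define SKETCH_CATCH_RETURN_RES(ty, dflt, stmt)     \
        ty res = dflt;                                  \
        try {                                           \
            stmt;                                       \
        } catch (const std::exception& e) {             \
            /* record e.what() for the managed side */  \
        }                                               \
        return res;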

src/TorchSharp/NN/AdaptiveAvgPool2D.cs (+8 −4)

@@ -10,7 +10,10 @@ namespace TorchSharp.NN
     /// </summary>
     public class AdaptiveAvgPool2D : Module
     {
-        internal AdaptiveAvgPool2D (IntPtr handle, IntPtr boxedHandle) : base (handle, boxedHandle)
+
+        [DllImport("LibTorchSharp")]
+        extern static IntPtr THSNN_AdaptiveAvgPool2d_wrap(IntPtr handle);
+        internal AdaptiveAvgPool2D (IntPtr handle) : base (handle, THSNN_AdaptiveAvgPool2d_wrap(handle))
         {
         }
 
@@ -27,15 +30,16 @@ public TorchTensor Forward (TorchTensor tensor)
     public static partial class Modules
     {
         [DllImport ("LibTorchSharp")]
-        extern static IntPtr THSNN_AdaptiveAvgPool2d_ctor (IntPtr psizes, int length, out IntPtr pBoxedModule);
+        extern static IntPtr THSNN_AdaptiveAvgPool2d_ctor (IntPtr psizes, int length);
 
         static public AdaptiveAvgPool2D AdaptiveAvgPool2D (long[] kernelSize)
         {
             unsafe {
                 fixed (long* pkernelSize = kernelSize) {
-                    var handle = THSNN_AdaptiveAvgPool2d_ctor ((IntPtr)pkernelSize, kernelSize.Length, out var boxedHandle);
+                    var handle = THSNN_AdaptiveAvgPool2d_ctor ((IntPtr)pkernelSize, kernelSize.Length);
                     Torch.CheckForErrors ();
-                    return new AdaptiveAvgPool2D (handle, boxedHandle);
+
+                    return new AdaptiveAvgPool2D (handle);
                 }
             }
         }

src/TorchSharp/NN/AvgPool2D.cs (+8 −4)

@@ -10,10 +10,14 @@ namespace TorchSharp.NN
     /// </summary>
     public class AvgPool2D : Module
    {
-        internal AvgPool2D (IntPtr handle, IntPtr boxedHandle) : base (handle, boxedHandle)
+        [DllImport("LibTorchSharp")]
+        extern static IntPtr THSNN_AvgPool2d_wrap(IntPtr handle);
+
+        internal AvgPool2D (IntPtr handle) : base (handle, THSNN_AvgPool2d_wrap(handle))
        {
        }
 
+
         [DllImport ("LibTorchSharp")]
         private static extern IntPtr THSNN_AvgPool2d_forward (IntPtr module, IntPtr tensor);
 
@@ -27,15 +31,15 @@ public TorchTensor Forward (TorchTensor tensor)
     public static partial class Modules
     {
         [DllImport ("LibTorchSharp")]
-        extern static IntPtr THSNN_AvgPool2d_ctor (IntPtr pkernelSize, int kernelSizeLength, IntPtr pstrides, int stridesLength, out IntPtr pBoxedModule);
+        extern static IntPtr THSNN_AvgPool2d_ctor (IntPtr pkernelSize, int kernelSizeLength, IntPtr pstrides, int stridesLength);
 
         static public AvgPool2D AvgPool2D (long[] kernelSize, long[] strides = null)
         {
             unsafe {
                 fixed (long* pkernelSize = kernelSize, pstrides = strides) {
-                    var handle = THSNN_AvgPool2d_ctor ((IntPtr)pkernelSize, kernelSize.Length, (IntPtr)pstrides, (strides == null ? 0 : strides.Length), out var boxedHandle);
+                    var handle = THSNN_AvgPool2d_ctor ((IntPtr)pkernelSize, kernelSize.Length, (IntPtr)pstrides, (strides == null ? 0 : strides.Length));
                     Torch.CheckForErrors ();
-                    return new AvgPool2D (handle, boxedHandle);
+                    return new AvgPool2D (handle);
                 }
             }
         }

src/TorchSharp/NN/Conv2D.cs (+7 −4)

@@ -7,7 +7,10 @@ namespace TorchSharp.NN
 {
     public class Conv2D : Module
     {
-        internal Conv2D (IntPtr handle, IntPtr boxedHandle) : base (handle, boxedHandle) { }
+        [DllImport("LibTorchSharp")]
+        extern static IntPtr THSNN_Conv2d_wrap(IntPtr handle);
+
+        internal Conv2D (IntPtr handle) : base (handle, THSNN_Conv2d_wrap(handle)) { }
 
         [DllImport ("LibTorchSharp")]
         private static extern IntPtr THSNN_Conv2d_forward (Module.HType module, IntPtr tensor);
@@ -22,13 +25,13 @@ public TorchTensor Forward (TorchTensor tensor)
     public static partial class Modules
     {
         [DllImport ("LibTorchSharp")]
-        private static extern IntPtr THSNN_Conv2d_ctor (long inputChannel, long outputChannel, long kernelSize, long stride, long padding, out IntPtr pBoxedModule);
+        private static extern IntPtr THSNN_Conv2d_ctor (long inputChannel, long outputChannel, long kernelSize, long stride, long padding);
 
         static public Conv2D Conv2D (long inputChannel, long outputChannel, long kernelSize, long stride = 1, long padding = 0)
         {
-            var res = THSNN_Conv2d_ctor (inputChannel, outputChannel, kernelSize, stride, padding, out var boxedHandle);
+            var res = THSNN_Conv2d_ctor (inputChannel, outputChannel, kernelSize, stride, padding);
             Torch.CheckForErrors ();
-            return new Conv2D (res, boxedHandle);
+            return new Conv2D (res);
         }
     }
     public static partial class Functions

src/TorchSharp/NN/Dropout.cs (+7 −4)

@@ -10,7 +10,10 @@ namespace TorchSharp.NN
     /// </summary>
     public class Dropout : Module
     {
-        internal Dropout (IntPtr handle, IntPtr boxedHandle) : base (handle, boxedHandle) { }
+        [DllImport("LibTorchSharp")]
+        extern static IntPtr THSNN_Dropout_wrap(IntPtr handle);
+
+        internal Dropout (IntPtr handle) : base (handle, THSNN_Dropout_wrap(handle)) { }
 
         [DllImport ("LibTorchSharp")]
         private static extern IntPtr THSNN_Dropout_forward (Module.HType module, IntPtr tensor);
@@ -25,13 +28,13 @@ public TorchTensor Forward (TorchTensor tensor)
     public static partial class Modules
     {
         [DllImport ("LibTorchSharp")]
-        extern static IntPtr THSNN_Dropout_ctor (double probability, out IntPtr pBoxedModule);
+        extern static IntPtr THSNN_Dropout_ctor (double probability);
 
         static public Dropout Dropout (double probability = 0.5)
         {
-            var handle = THSNN_Dropout_ctor (probability, out var boxedHandle);
+            var handle = THSNN_Dropout_ctor (probability);
             Torch.CheckForErrors ();
-            return new Dropout (handle, boxedHandle);
+            return new Dropout (handle);
         }
     }

src/TorchSharp/NN/FeatureDropout.cs (+7 −4)

@@ -10,7 +10,10 @@ namespace TorchSharp.NN
     /// </summary>
     public class FeatureAlphaDropout : Module
     {
-        internal FeatureAlphaDropout (IntPtr handle, IntPtr boxedHandle) : base (handle, boxedHandle)
+        [DllImport("LibTorchSharp")]
+        extern static IntPtr THSNN_FeatureAlphaDropout_wrap(IntPtr handle);
+
+        internal FeatureAlphaDropout (IntPtr handle) : base (handle, THSNN_FeatureAlphaDropout_wrap(handle))
         {
         }
 
@@ -27,13 +30,13 @@ public TorchTensor Forward (TorchTensor tensor)
     public static partial class Modules
     {
         [DllImport ("LibTorchSharp")]
-        extern static IntPtr THSNN_FeatureAlphaDropout_ctor (double probability, out IntPtr pBoxedModule);
+        extern static IntPtr THSNN_FeatureAlphaDropout_ctor (double probability);
 
         static public FeatureAlphaDropout FeatureAlphaDropout (double probability = 0.5)
         {
-            var handle = THSNN_FeatureAlphaDropout_ctor (probability, out var boxedHandle);
+            var handle = THSNN_FeatureAlphaDropout_ctor (probability);
             Torch.CheckForErrors ();
-            return new FeatureAlphaDropout (handle, boxedHandle);
+            return new FeatureAlphaDropout (handle);
         }
     }

src/TorchSharp/NN/Linear.cs (+6 −10)

@@ -9,14 +9,10 @@ namespace TorchSharp.NN
 {
     public class Linear : Module
     {
-        internal Linear (IntPtr handle, IntPtr boxedHandle) : base (handle, boxedHandle) { }
+        [DllImport("LibTorchSharp")]
+        extern static IntPtr THSNN_Linear_wrap(IntPtr handle);
 
-        public new static Linear Load (String modelPath)
-        {
-            var res = Module.Load (modelPath);
-            Torch.CheckForErrors ();
-            return new Linear (res.handle.DangerousGetHandle(), IntPtr.Zero);
-        }
+        internal Linear (IntPtr handle) : base (handle, THSNN_Linear_wrap(handle)) { }
 
         [DllImport ("LibTorchSharp")]
         extern static IntPtr THSNN_Linear_forward (Module.HType module, IntPtr tensor);
@@ -63,13 +59,13 @@ public TorchTensor Weight {
     public static partial class Modules
     {
         [DllImport ("LibTorchSharp")]
-        private static extern IntPtr THSNN_Linear_ctor (long input_size, long output_size, bool bias, out IntPtr pBoxedModule);
+        private static extern IntPtr THSNN_Linear_ctor (long input_size, long output_size, bool bias);
 
         static public Linear Linear (long inputSize, long outputSize, bool hasBias = true)
         {
-            var res = THSNN_Linear_ctor (inputSize, outputSize, hasBias, out var boxedHandle);
+            var res = THSNN_Linear_ctor (inputSize, outputSize, hasBias);
             Torch.CheckForErrors ();
-            return new Linear (res, boxedHandle);
+            return new Linear (res);
         }
     }
     public static partial class Functions

src/TorchSharp/NN/LogSoftMax.cs (+9 −7)

@@ -10,30 +10,32 @@ namespace TorchSharp.NN
     /// </summary>
     public class LogSoftMax : Module
     {
-        internal LogSoftMax (IntPtr handle, IntPtr boxedHandle) : base (handle, boxedHandle)
+        [DllImport("LibTorchSharp")]
+        extern static IntPtr THSNN_LogSoftmax_wrap(IntPtr handle);
+
+        internal LogSoftMax (IntPtr handle) : base (handle, THSNN_LogSoftmax_wrap(handle))
         {
         }
 
         [DllImport ("LibTorchSharp")]
-        private static extern IntPtr THSNN_LogSoftMax_forward (Module.HType handle, IntPtr tensor);
+        private static extern IntPtr THSNN_LogSoftmax_forward (Module.HType handle, IntPtr tensor);
 
         public TorchTensor Forward (TorchTensor tensor)
         {
-            var res = THSNN_LogSoftMax_forward (handle, tensor.Handle);
+            var res = THSNN_LogSoftmax_forward (handle, tensor.Handle);
             Torch.CheckForErrors ();
             return new TorchTensor (res);
         }
     }
     public static partial class Modules
     {
         [DllImport ("LibTorchSharp")]
-        extern static IntPtr THSNN_LogSoftMax_ctor (long dimension, out IntPtr pBoxedModule);
+        extern static IntPtr THSNN_LogSoftmax_ctor (long dimension);
 
         static public LogSoftMax LogSoftMax (long dimension)
         {
-            var handle = THSNN_LogSoftMax_ctor (dimension, out var boxedHandle);
-            Torch.CheckForErrors ();
-            return new LogSoftMax (handle, boxedHandle);
+            var handle = THSNN_LogSoftmax_ctor (dimension);
+            return new LogSoftMax (handle);
         }
     }

src/TorchSharp/NN/MaxPool2D.cs (+7 −4)

@@ -10,7 +10,10 @@ namespace TorchSharp.NN
     /// </summary>
     public class MaxPool2D : Module
     {
-        internal MaxPool2D (IntPtr handle, IntPtr boxedHandle) : base (handle, boxedHandle)
+        [DllImport("LibTorchSharp")]
+        extern static IntPtr THSNN_MaxPool2d_wrap(IntPtr handle);
+
+        internal MaxPool2D (IntPtr handle) : base (handle, THSNN_MaxPool2d_wrap(handle))
         {
         }
 
@@ -27,15 +30,15 @@ public TorchTensor Forward (TorchTensor tensor)
     public static partial class Modules
     {
         [DllImport ("LibTorchSharp")]
-        extern static IntPtr THSNN_MaxPool2d_ctor (IntPtr pkernelSize, int kernelSizeLength, IntPtr pstrides, int stridesLength, out IntPtr pBoxedModule);
+        extern static IntPtr THSNN_MaxPool2d_ctor (IntPtr pkernelSize, int kernelSizeLength, IntPtr pstrides, int stridesLength);
 
         static public MaxPool2D MaxPool2D (long[] kernelSize, long[] strides = null)
         {
             unsafe {
                 fixed (long* pkernelSize = kernelSize, pstrides = strides) {
-                    var handle = THSNN_MaxPool2d_ctor ((IntPtr)pkernelSize, kernelSize.Length, (IntPtr)pstrides, (strides == null ? 0 : strides.Length), out var boxedHandle);
+                    var handle = THSNN_MaxPool2d_ctor ((IntPtr)pkernelSize, kernelSize.Length, (IntPtr)pstrides, (strides == null ? 0 : strides.Length));
                     Torch.CheckForErrors ();
-                    return new MaxPool2D (handle, boxedHandle);
+                    return new MaxPool2D (handle);
                 }
             }
         }

src/TorchSharp/NN/Module.cs (+7 −11)

@@ -91,15 +91,7 @@ protected void Dispose (bool disposing)
             handle.SetHandleAsInvalid ();
         }
     }
-    [DllImport("LibTorchSharp")]
-    extern static IntPtr THSNN_Module_load([MarshalAs(UnmanagedType.LPStr)] string location);
 
-    public static Module Load(String location)
-    {
-        var handle = THSNN_Module_load (location);
-        Torch.CheckForErrors ();
-        return new Module (handle, IntPtr.Zero);
-    }
 
     [DllImport ("LibTorchSharp")]
     extern static void THSNN_Module_save (HType handle, [MarshalAs(UnmanagedType.LPStr)] string location);
@@ -320,9 +312,12 @@ public abstract class CustomModule : Module
     private delegate IntPtr ForwardFunctionC (IntPtr tensor);
 
     [DllImport ("LibTorchSharp")]
-    private static extern IntPtr THSNN_custom_module([MarshalAs(UnmanagedType.LPStr)] string name,
+    private static extern IntPtr THSNN_CustomModule_ctor([MarshalAs(UnmanagedType.LPStr)] string name,
         IntPtr names, IntPtr parameters, IntPtr require_grad,
-        int length, ForwardFunctionC forward, out IntPtr pBoxedModule);
+        int length, ForwardFunctionC forward);
+
+    [DllImport("LibTorchSharp")]
+    extern static IntPtr THSNN_CustomModule_wrap(IntPtr handle);
 
     protected CustomModule (string name, params Parameter[] parameters) : base (IntPtr.Zero, IntPtr.Zero)
     {
@@ -339,10 +334,11 @@ protected CustomModule (string name, params Parameter[] parameters) : base (IntP
         var gparray = wGradPinned.CreateArray (withGrads);
 
         ForwardFunctionC forwardNative = t => (Forward (new TorchTensor (t)).Handle);
-        var res = THSNN_custom_module (name, nparray, pparray, gparray, names.Length, forwardNative, out var boxedHandle);
+        var res = THSNN_CustomModule_ctor (name, nparray, pparray, gparray, names.Length, forwardNative);
         Torch.CheckForErrors ();
         this.handle = new HType (res, true);
         this.forwardNative = forwardNative;
+        var boxedHandle = THSNN_CustomModule_wrap(res);
         this.boxedModule = new BoxedModule(boxedHandle);
     }
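On the managed side, the boxed handle is no longer produced by an out-parameter during construction; CustomModule now makes a second, explicit native call after the module handle has been created and checked. The native-side equivalent of that order, sketched against the THSNN API from this commit (make_boxed_custom is illustrative, not part of the commit):

    #include "THSNN.h"

    // Mirrors the managed constructor's new order: ctor first, check, then wrap.
    NNAnyModule make_boxed_custom(const char* name, const char** names,
                                  at::Tensor** parameters, const bool* require_grad,
                                  const int length, Tensor (*forward)(Tensor))
    {
        NNModule mod = THSNN_CustomModule_ctor(name, names, parameters,
                                               require_grad, length, forward);
        // The managed caller invokes Torch.CheckForErrors() at this point,
        // so only a successfully constructed module is ever boxed.
        return THSNN_CustomModule_wrap(mod);
    }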

src/TorchSharp/NN/ReLu.cs (+7 −4)

@@ -10,7 +10,10 @@ namespace TorchSharp.NN
     /// </summary>
     public class ReLU : Module
     {
-        internal ReLU (IntPtr handle, IntPtr boxedHandle) : base (handle, boxedHandle) { }
+        [DllImport("LibTorchSharp")]
+        extern static IntPtr THSNN_ReLU_wrap(IntPtr handle);
+
+        internal ReLU (IntPtr handle) : base (handle, THSNN_ReLU_wrap(handle)) { }
 
         [DllImport ("LibTorchSharp")]
         private static extern IntPtr THSNN_ReLU_forward (Module.HType module, IntPtr tensor);
@@ -31,13 +34,13 @@ public override string GetName ()
     public static partial class Modules
     {
         [DllImport ("LibTorchSharp")]
-        extern static IntPtr THSNN_ReLU_ctor (bool inplace, out IntPtr pBoxedModule);
+        extern static IntPtr THSNN_ReLU_ctor (bool inplace);
 
         static public ReLU Relu (bool inPlace = false)
         {
-            var handle = THSNN_ReLU_ctor (inPlace, out var boxedHandle);
+            var handle = THSNN_ReLU_ctor (inPlace);
             Torch.CheckForErrors ();
-            return new ReLU (handle, boxedHandle);
+            return new ReLU (handle);
         }
     }
     public static partial class Functions

test/TorchSharpTest/TorchSharp.cs (+6 −12)

@@ -967,38 +967,32 @@ public void TestSubInPlace()
         }
     }
     [Fact]
-    public void TestSaveLoadLinear()
+    public void TestSaveLinear()
     {
         if (File.Exists (".model.ts")) File.Delete (".model.ts");
         var linear = Linear(100, 10, true);
         linear.Save(".model.ts");
-        var loadedLinear = NN.Linear.Load(".model.ts");
-        File.Delete(".model.ts");
-        Assert.NotNull(loadedLinear);
+        //var loadedLinear = NN.Linear.Load(".model.ts");
+        //File.Delete(".model.ts");
+        //Assert.NotNull(loadedLinear);
     }
 
     [Fact]
-    public void TestSaveLoadConv2D()
+    public void TestSaveConv2D()
     {
         if (File.Exists (".model.ts")) File.Delete (".model.ts");
         var conv = Conv2D(100, 10, 5);
         conv.Save(".model.ts");
-        var loaded = NN.Conv2D.Load(".model.ts");
-        File.Delete(".model.ts");
-        Assert.NotNull(loaded);
     }
 
     [Fact]
-    public void TestSaveLoadSequence()
+    public void TestSaveSequence()
    {
         if (File.Exists (".model-list.txt")) File.Delete (".model-list.txt");
         var lin1 = Linear(100, 10, true);
         var lin2 = Linear(10, 5, true);
         var seq = Sequential(("lin1", lin1), ("lin2", lin2));
         seq.Save(".model-list.txt");
-        var loaded = NN.Sequential.Load(".model-list.txt");
-        File.Delete("model-list.txt");
-        Assert.NotNull(loaded);
     }
 
     [Fact]
