diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile new file mode 100644 index 0000000..d841c4e --- /dev/null +++ b/.devcontainer/Dockerfile @@ -0,0 +1,18 @@ +ARG BASE_IMAGE_TAG + +FROM mcr.microsoft.com/devcontainers/cpp:$BASE_IMAGE_TAG + +ARG OCAML_VERSION +ARG OPAM_VERSION +ARG TORCH_VERSION + +RUN sudo apt-get update \ + && sudo apt-get satisfy -y "ocaml (>= $OCAML_VERSION)" "opam (>= $OPAM_VERSION)" \ + && rm -rf /var/lib/apt/lists/* + +RUN cd /usr/local \ + && sudo wget https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-$TORCH_VERSION%2Bcpu.zip \ + && sudo unzip libtorch-*.zip \ + && sudo rm libtorch-*.zip + +ENV CMAKE_PREFIX_PATH=/usr/local/libtorch diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 0000000..e316adb --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,22 @@ +{ + "build": { + "dockerfile": "Dockerfile", + "args": { + "BASE_IMAGE_TAG": "debian-11", + "OCAML_VERSION": "4", + "OPAM_VERSION": "2", + "TORCH_VERSION": "1.4.0" + } + }, + "customizations": { + "vscode": { + "extensions": [ + "ms-vscode.cpptools-extension-pack" + ] + } + }, + "features": { + "ghcr.io/julialang/devcontainer-features/julia:1": {} + }, + "postCreateCommand": "opam init --auto-setup" +} diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..598ae79 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +/.vscode diff --git a/build/README.md b/build/README.md deleted file mode 100644 index dbc1e6a..0000000 --- a/build/README.md +++ /dev/null @@ -1,18 +0,0 @@ -# Building the Wrapper - -The project can be built given that we can provide the paths to working Torch and CUDA/ CUDNN projects. The binaries can also be downloaded from the [official libtorch binaries](https://pytorch.org/get-started/locally/), which is what the wrapper is based on. This currently supports torch v1.4.0. 
- -```code -$ mkdir build && cd build - -# With a working torch install via Python (or similar): setting the CMAKE_PREFIX_PATH to point there might be sufficient -$ CMAKE_PREFIX_PATH=$HOME/.local/lib/python3.6/site-packages/torch\ - CUDNN_LIBRARY_PATH=$HOME/cuda/lib64\ - CUDNN_INCLUDE_PATH=$HOME/cuda/include\ - CUDNN_INCLUDE_DIR=$HOME/cuda/include\ - cmake .. - -$ cmake --build . -``` - -Post this, adding the path to the project via the `LD_LIBRARY_PATH` (and also the CUDNN) binary path might be needed. diff --git a/deps/README.md b/deps/README.md new file mode 100644 index 0000000..1dd7a17 --- /dev/null +++ b/deps/README.md @@ -0,0 +1,66 @@ +# C Wrapper + +Since Torch is a C++-library, a C wrapper is needed for Julia to interact with Torch. + +## Generating + +The C wrapper can be generated from the `Declarations.yaml`-file included with `Torch_jll`: +```sh +mkdir c_wrapper_generator/data +curl https://raw.githubusercontent.com/LaurentMazare/ocaml-torch/main/third_party/pytorch/Declarations-v1.4.0.yaml -o c_wrapper_generator/data/Declarations.yaml +``` + +The C wrapper can then be generated by building and running the (OCaml-based) C wrapper generator, e.g. by using the dev. container (which includes OCaml and OPAM): +```sh +cd c_wrapper_generator +opam install -y . --deps-only +opam exec -- dune build +_build/default/bin/main.exe +``` +or by using an OCaml-container: +```sh +docker run -it --rm -v `pwd`:/workspace -w /workspace ocaml/opam:debian-11-ocaml-4.12 bash -c ' + cd c_wrapper_generator + opam install -y . --deps-only + opam exec -- dune build + _build/default/bin/main.exe +' +``` + +## Building + +The C wrapper can be built given that we can provide the paths to Torch, CUDA, and CUDNN. Torch can be downloaded from the [official libtorch binaries](https://pytorch.org/get-started/locally/). + +### VS Code + +The C wrapper can be built from VS Code, provided that CMake is configured appropriately, e.g. 
by ensuring the following settings are in `.vscode/settings.json` (using the dev. container - or ensuring otherwise that `$CMAKE_PREFIX_PATH` is the path to `libtorch`): + +```json +{ + "cmake.buildDirectory": "${workspaceFolder}/deps/c_wrapper/build", + "cmake.buildEnvironment": { + "CMAKE_PREFIX_PATH": "$CMAKE_PREFIX_PATH" + }, + "cmake.sourceDirectory": "${workspaceFolder}/deps/c_wrapper" +} +``` + +### Manually + +The C wrapper can also be built manually, e.g. + +```sh +cd c_wrapper +mkdir build && cd build + +# With a working torch install via Python (or similar): setting the CMAKE_PREFIX_PATH to point there might be sufficient +CMAKE_PREFIX_PATH=$HOME/.local/lib/python3.6/site-packages/torch\ + CUDNN_LIBRARY_PATH=$HOME/cuda/lib64\ + CUDNN_INCLUDE_PATH=$HOME/cuda/include\ + CUDNN_INCLUDE_DIR=$HOME/cuda/include\ + cmake .. + +cmake --build . +``` + +After this, adding the path to the project (and also the CUDNN binary path) to `LD_LIBRARY_PATH` might be needed. diff --git a/deps/c_wrapper/.gitignore b/deps/c_wrapper/.gitignore new file mode 100644 index 0000000..378eac2 --- /dev/null +++ b/deps/c_wrapper/.gitignore @@ -0,0 +1 @@ +build diff --git a/build/CMakeLists.txt b/deps/c_wrapper/CMakeLists.txt similarity index 100% rename from build/CMakeLists.txt rename to deps/c_wrapper/CMakeLists.txt diff --git a/build/torch_api.cpp b/deps/c_wrapper/torch_api.cpp similarity index 100% rename from build/torch_api.cpp rename to deps/c_wrapper/torch_api.cpp diff --git a/build/torch_api.h b/deps/c_wrapper/torch_api.h similarity index 100% rename from build/torch_api.h rename to deps/c_wrapper/torch_api.h diff --git a/build/torch_api_generated.cpp.h b/deps/c_wrapper/torch_api_generated.cpp.h similarity index 100% rename from build/torch_api_generated.cpp.h rename to deps/c_wrapper/torch_api_generated.cpp.h diff --git a/build/torch_api_generated.h b/deps/c_wrapper/torch_api_generated.h similarity index 100% rename from build/torch_api_generated.h rename to 
deps/c_wrapper/torch_api_generated.h diff --git a/deps/c_wrapper_generator/.gitignore b/deps/c_wrapper_generator/.gitignore new file mode 100644 index 0000000..9cc2712 --- /dev/null +++ b/deps/c_wrapper_generator/.gitignore @@ -0,0 +1,2 @@ +_build +data diff --git a/deps/c_wrapper_generator/bin/dune b/deps/c_wrapper_generator/bin/dune new file mode 100644 index 0000000..869412a --- /dev/null +++ b/deps/c_wrapper_generator/bin/dune @@ -0,0 +1,4 @@ +(executable + (public_name wrapper_generator) + (name main) + (libraries base stdio yaml)) diff --git a/deps/c_wrapper_generator/bin/main.ml b/deps/c_wrapper_generator/bin/main.ml new file mode 100644 index 0000000..91b7d5a --- /dev/null +++ b/deps/c_wrapper_generator/bin/main.ml @@ -0,0 +1,323 @@ +(* Automatically generate the C++ -> C bindings. + This takes as input the Declarations.yaml file that gets generated when + building PyTorch from source. + *) +open Base +open Stdio + +let excluded_functions = + Set.of_list + (module String) + [ "multi_margin_loss" + ; "multi_margin_loss_out" + ; "log_softmax_backward_data" + ; "softmax_backward_data" + ; "copy_" + ; "conv_transpose2d_backward_out" + ; "conv_transpose3d_backward_out" + ; "slow_conv_transpose2d_backward_out" + ; "slow_conv_transpose3d_backward_out" + ; "slow_conv3d_backward_out" + ; "normal" + ; "_cufft_set_plan_cache_max_size" + ; "_cufft_clear_plan_cache" + ; "backward" + ; "set_data" + ] + +let excluded_prefixes = [ "_"; "thnn_"; "th_" ] +let excluded_suffixes = [ "_forward"; "_forward_out" ] +let yaml_error yaml ~msg = Printf.failwithf "%s, %s" msg (Yaml.to_string_exn yaml) () + +let extract_bool = function + | `Bool b -> b + | `String "true" -> true + | `String "false" -> false + | yaml -> yaml_error yaml ~msg:"expected bool" + +let extract_list = function + | `A l -> l + | yaml -> yaml_error yaml ~msg:"expected list" + +let extract_map = function + | `O map -> Map.of_alist_exn (module String) map + | yaml -> yaml_error yaml ~msg:"expected map" + +let 
extract_string = function + | `String s -> s + (* The yaml spec for torch uses n which is converted to a bool. *) + | `Bool b -> if b then "y" else "n" + | `Float f -> Float.to_string f + | yaml -> yaml_error yaml ~msg:"expected string" + +module Func = struct + type arg_type = + | Bool + | Int64 + | Double + | Tensor + | TensorOption + | IntList + | TensorList + | TensorOptions + | Scalar + | ScalarType + | Device + + type arg = + { arg_name : string + ; arg_type : arg_type + ; default_value : string option + } + + type t = + { name : string + ; args : arg list + ; returns : (* number of tensors that are returned *) + [ `fixed of int | `dynamic ] + ; kind : [ `function_ | `method_ ] + } + + let arg_type_of_string str ~is_nullable = + match String.lowercase str with + | "bool" -> Some Bool + | "int64_t" -> Some Int64 + | "double" -> Some Double + | "booltensor" | "indextensor" | "tensor" -> + Some (if is_nullable then TensorOption else Tensor) + | "tensoroptions" -> Some TensorOptions + | "intarrayref" | "intlist" -> Some IntList + | "tensorlist" -> Some TensorList + | "device" -> Some Device + | "scalar" -> Some Scalar + | "scalartype" -> Some ScalarType + | _ -> None + + let c_typed_args_list t = + List.map t.args ~f:(fun { arg_name; arg_type; _ } -> + match arg_type with + | IntList -> Printf.sprintf "int64_t *%s_data, int %s_len" arg_name arg_name + | TensorList -> Printf.sprintf "tensor *%s_data, int %s_len" arg_name arg_name + | TensorOptions -> Printf.sprintf "int %s_kind, int %s_device" arg_name arg_name + | otherwise -> + let simple_type_cstring = + match otherwise with + | Bool -> "int" + | Int64 -> "int64_t" + | Double -> "double" + | Tensor -> "tensor" + | TensorOption -> "tensor" + | ScalarType -> "int" + | Device -> "int" + | Scalar -> "scalar" + | IntList | TensorList | TensorOptions -> assert false + in + Printf.sprintf "%s %s" simple_type_cstring arg_name) + |> String.concat ~sep:", " + + let c_args_list args = + List.map args ~f:(fun { arg_name; 
arg_type; _ } -> + match arg_type with + | Scalar | Tensor -> "*" ^ arg_name + | TensorOption -> Printf.sprintf "(%s ? *%s : torch::Tensor())" arg_name arg_name + | Bool -> "(bool)" ^ arg_name + | IntList -> + Printf.sprintf "torch::IntArrayRef(%s_data, %s_len)" arg_name arg_name + | TensorList -> + Printf.sprintf "of_carray_tensor(%s_data, %s_len)" arg_name arg_name + | TensorOptions -> + Printf.sprintf + "at::device(device_of_int(%s_device)).dtype(at::ScalarType(%s_kind))" + arg_name + arg_name + | ScalarType -> Printf.sprintf "torch::ScalarType(%s)" arg_name + | Device -> Printf.sprintf "device_of_int(%s)" arg_name + | _ -> arg_name) + |> String.concat ~sep:", " + + let c_call t = + match t.kind with + | `function_ -> Printf.sprintf "torch::%s(%s)" t.name (c_args_list t.args) + | `method_ -> + (match t.args with + | head :: tail -> + Printf.sprintf "%s->%s(%s)" head.arg_name t.name (c_args_list tail) + | [] -> + Printf.failwithf "Method calls should have at least one argument %s" t.name ()) +end + +exception Not_a_simple_arg + +let read_yaml filename = + let funcs = + (* Split the file to avoid Yaml.of_string_exn segfaulting. *) + In_channel.with_file filename ~f:In_channel.input_lines + |> List.group ~break:(fun _ l -> String.length l > 0 && Char.( = ) l.[0] '-') + |> List.concat_map ~f:(fun lines -> + Yaml.of_string_exn (String.concat lines ~sep:"\n") |> extract_list) + in + printf "Read %s, got %d functions.\n%!" 
filename (List.length funcs); + List.filter_map funcs ~f:(fun yaml -> + let map = extract_map yaml in + let name = Map.find_exn map "name" |> extract_string in + let deprecated = Map.find_exn map "deprecated" |> extract_bool in + let method_of = + Map.find_exn map "method_of" |> extract_list |> List.map ~f:extract_string + in + let arguments = Map.find_exn map "arguments" |> extract_list in + let returns = + let is_tensor returns = + let returns = extract_map returns in + let return_type = Map.find_exn returns "dynamic_type" |> extract_string in + String.( = ) return_type "Tensor" + || String.( = ) return_type "BoolTensor" + || String.( = ) return_type "IndexTensor" + in + let returns = Map.find_exn map "returns" |> extract_list in + if List.for_all returns ~f:is_tensor + then Some (`fixed (List.length returns)) + else ( + match returns with + | [ returns ] -> + let return_type = + Map.find_exn (extract_map returns) "dynamic_type" |> extract_string + in + if String.( = ) return_type "TensorList" then Some `dynamic else None + | [] | _ :: _ :: _ -> None) + in + let kind = + if List.exists method_of ~f:(String.( = ) "namespace") + then Some `function_ + else if List.exists method_of ~f:(String.( = ) "Tensor") + then Some `method_ + else None + in + if (not deprecated) + && (not + (List.exists excluded_prefixes ~f:(fun prefix -> + String.is_prefix name ~prefix))) + && (not + (List.exists excluded_suffixes ~f:(fun suffix -> + String.is_suffix name ~suffix))) + && not (Set.mem excluded_functions name) + then + Option.both returns kind + |> Option.bind ~f:(fun (returns, kind) -> + try + let args = + List.filter_map arguments ~f:(fun arg -> + let arg = extract_map arg in + let arg_name = Map.find_exn arg "name" |> extract_string in + let arg_type = Map.find_exn arg "dynamic_type" |> extract_string in + let is_nullable = + Map.find arg "is_nullable" + |> Option.value_map ~default:false ~f:extract_bool + in + let default_value = + Map.find arg "default" |> Option.map 
~f:extract_string + in + match Func.arg_type_of_string arg_type ~is_nullable with + | Some Scalar when Option.is_some default_value && not is_nullable + -> None + | Some arg_type -> Some { Func.arg_name; arg_type; default_value } + | None -> + if Option.is_some default_value + then None + else raise Not_a_simple_arg) + in + Some { Func.name; args; returns; kind } + with + | Not_a_simple_arg -> None) + else None) + +let p out_channel s = + Printf.ksprintf + (fun line -> + Out_channel.output_string out_channel line; + Out_channel.output_char out_channel '\n') + s + +let write_cpp funcs filename = + Out_channel.with_file (filename ^ ".cpp.h") ~f:(fun out_cpp -> + Out_channel.with_file (filename ^ ".h") ~f:(fun out_h -> + let pc s = p out_cpp s in + let ph s = p out_h s in + pc "// THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND!"; + pc ""; + ph "// THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND!"; + ph ""; + Map.iteri funcs ~f:(fun ~key:exported_name ~data:func -> + let c_typed_args_list = Func.c_typed_args_list func in + match func.returns with + | `dynamic -> + pc "int atg_%s(tensor *out__, %s) {" exported_name c_typed_args_list; + pc " PROTECT("; + pc " auto outputs__ = %s;" (Func.c_call func); + (* the returned type is a C++ vector of tensors *) + pc " int sz = outputs__.size();"; + pc + " // torch::Tensor **out__ = (torch::Tensor**)malloc((sz + 1) * \ + sizeof(torch::Tensor*));"; + pc " for (int i = 0; i < sz; ++i)"; + pc " out__[i] = new torch::Tensor(outputs__[i]);"; + pc " out__[sz] = nullptr;"; + pc " // return out__;"; + pc " return 0;"; + pc ")"; + pc "return 1;"; + pc "}"; + pc ""; + ph "// tensor *atg_%s(%s);" exported_name c_typed_args_list; + ph "int atg_%s(tensor *, %s);" exported_name c_typed_args_list + | `fixed ntensors -> + pc "int atg_%s(tensor *out__, %s) {" exported_name c_typed_args_list; + pc " PROTECT("; + pc " auto outputs__ = %s;" (Func.c_call func); + if ntensors = 1 + then pc " out__[0] = new 
torch::Tensor(outputs__);" + else + for i = 0 to ntensors - 1 do + pc " out__[%d] = new torch::Tensor(std::get<%d>(outputs__));" i i + done; + pc " return 0;"; + pc ")"; + pc "return 1;"; + pc "}"; + pc ""; + ph "int atg_%s(tensor *, %s);" exported_name c_typed_args_list))) + +let methods = + let c name args = { Func.name; args; returns = `fixed 1; kind = `method_ } in + let ca arg_name arg_type = { Func.arg_name; arg_type; default_value = None } in + [ c "grad" [ ca "self" Tensor ] + ; c "set_requires_grad" [ ca "self" Tensor; ca "r" Bool ] + ; c "toType" [ ca "self" Tensor; ca "scalar_type" ScalarType ] + ; c "to" [ ca "self" Tensor; ca "device" Device ] + ] + +let run ~yaml_filename ~cpp_filename = + let funcs = read_yaml yaml_filename in + let funcs = methods @ funcs in + printf "Generating code for %d functions.\n%!" (List.length funcs); + (* Generate some unique names for overloaded functions. *) + let funcs = + List.map funcs ~f:(fun func -> String.lowercase func.name, func) + |> Map.of_alist_multi (module String) + |> Map.to_alist + |> List.concat_map ~f:(fun (name, funcs) -> + match funcs with + | [] -> assert false + | [ func ] -> [ name, func ] + | funcs -> + List.sort funcs ~compare:(fun (f1 : Func.t) (f2 : Func.t) -> + Int.compare (List.length f1.args) (List.length f2.args)) + |> List.mapi ~f:(fun i func -> + (if i = 0 then name else Printf.sprintf "%s%d" name i), func)) + |> Map.of_alist_exn (module String) + in + write_cpp funcs cpp_filename + +let () = + run + ~yaml_filename:"data/Declarations.yaml" + ~cpp_filename:"../c_wrapper/torch_api_generated" diff --git a/deps/c_wrapper_generator/dune-project b/deps/c_wrapper_generator/dune-project new file mode 100644 index 0000000..e187f54 --- /dev/null +++ b/deps/c_wrapper_generator/dune-project @@ -0,0 +1,26 @@ +(lang dune 3.4) + +(name wrapper_generator) + +(generate_opam_files true) + +(source + (github username/reponame)) + +(authors "Author Name") + +(maintainers "Maintainer Name") + +(license 
LICENSE) + +(documentation https://url/to/documentation) + +(package + (name wrapper_generator) + (synopsis "A short synopsis") + (description "A longer description") + (depends ocaml dune base stdio yaml) + (tags + (topics "to describe" your project))) + +; See the complete stanza docs at https://dune.readthedocs.io/en/stable/dune-files.html#dune-project diff --git a/deps/c_wrapper_generator/wrapper_generator.opam b/deps/c_wrapper_generator/wrapper_generator.opam new file mode 100644 index 0000000..b683032 --- /dev/null +++ b/deps/c_wrapper_generator/wrapper_generator.opam @@ -0,0 +1,34 @@ +# This file is generated by dune, edit dune-project instead +opam-version: "2.0" +synopsis: "A short synopsis" +description: "A longer description" +maintainer: ["Maintainer Name"] +authors: ["Author Name"] +license: "LICENSE" +tags: ["topics" "to describe" "your" "project"] +homepage: "https://github.com/username/reponame" +doc: "https://url/to/documentation" +bug-reports: "https://github.com/username/reponame/issues" +depends: [ + "ocaml" + "dune" {>= "3.4"} + "base" + "stdio" + "yaml" + "odoc" {with-doc} +] +build: [ + ["dune" "subst"] {dev} + [ + "dune" + "build" + "-p" + name + "-j" + jobs + "@install" + "@runtest" {with-test} + "@doc" {with-doc} + ] +] +dev-repo: "git+https://github.com/username/reponame.git"