Skip to content

Commit

Permalink
Add CUDA support for Windows (Macoron#61)
Browse files Browse the repository at this point in the history
* Add support for params context

* Add system info print function

* Changed script to build CUDA

* Update build_cpp.yml

* Update build_cpp.yml

* Update build_cpp.yml

* Added cuda dll

* Minor changes

* Fix older unity

* And this

* Update README.md
  • Loading branch information
Macoron authored Nov 21, 2023
1 parent f2f432e commit d9b53a7
Show file tree
Hide file tree
Showing 9 changed files with 311 additions and 22 deletions.
89 changes: 89 additions & 0 deletions Editor/WhisperProjectSettings.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
using System;
using System.Collections.Generic;
using System.Linq;
using UnityEditor;
using UnityEditor.Build;
using UnityEngine;


static class WhisperProjectSettingsProvider
{
[SettingsProvider]
public static SettingsProvider CreateMyCustomSettingsProvider()
{
var provider = new SettingsProvider("Project/WhisperSettings", SettingsScope.Project)
{
label = "Whisper",
guiHandler = (searchContext) =>
{
CudaEnabled = EditorGUILayout.Toggle("Enable CUDA", CudaEnabled);
},

keywords = new HashSet<string>(new[] { "CUDA", "cuBLAS" })
};

return provider;
}

public static bool CudaEnabled
{
get
{
#if WHISPER_CUDA
return true;
#else
return false;
#endif
}
set
{
if (value == CudaEnabled)
return;

string[] newDefines;
var defines = GetStandaloneDefines();

if (value)
{
if (defines.Contains("WHISPER_CUDA"))
return;

newDefines = defines.Append("WHISPER_CUDA").ToArray();
}
else
{
if (!defines.Contains("WHISPER_CUDA"))
return;

newDefines = defines.Where(x => x != "WHISPER_CUDA").ToArray();
}

SetStandaloneDefines(newDefines);
}
}

// This is for older Unity compability
private static string[] GetStandaloneDefines()
{
string[] defines;

#if UNITY_2021_3_OR_NEWER
PlayerSettings.GetScriptingDefineSymbols(NamedBuildTarget.Standalone, out defines);
#else
var definesStr = PlayerSettings.GetScriptingDefineSymbolsForGroup(BuildTargetGroup.Standalone);
defines = definesStr.Split(';');
#endif

return defines;
}

private static void SetStandaloneDefines(string[] newDefines)
{
#if UNITY_2021_3_OR_NEWER
PlayerSettings.SetScriptingDefineSymbols(NamedBuildTarget.Standalone, newDefines);
#else
var definesStr = string.Join(";", newDefines);
PlayerSettings.SetScriptingDefineSymbolsForGroup(BuildTargetGroup.Standalone, definesStr);
#endif
}
}
11 changes: 11 additions & 0 deletions Editor/WhisperProjectSettings.cs.meta

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Plugins/Windows/libwhisper.dll.meta

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Binary file added Plugins/Windows/libwhisper_cuda.dll
Binary file not shown.
70 changes: 70 additions & 0 deletions Plugins/Windows/libwhisper_cuda.dll.meta

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 17 additions & 6 deletions Runtime/Native/WhisperNative.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,22 @@ namespace Whisper.Native
/// </summary>
public static unsafe class WhisperNative
{

#if (UNITY_IOS || UNITY_ANDROID) && !UNITY_EDITOR
private const string LibraryName = "__Internal";
#else
#if UNITY_STANDALONE_WIN && WHISPER_CUDA
private const string LibraryName = "libwhisper_cuda";
#else
private const string LibraryName = "libwhisper";
#endif
#endif

[DllImport(LibraryName)]
public static extern whisper_context_ptr whisper_init_from_file(string path_model);
public static extern whisper_context_ptr whisper_init_from_file_with_params(string path_model, WhisperNativeContextParams @params);

[DllImport(LibraryName)]
public static extern whisper_context_ptr whisper_init_from_buffer(IntPtr buffer, UIntPtr buffer_size);
public static extern whisper_context_ptr whisper_init_from_buffer_with_params(IntPtr buffer, UIntPtr buffer_size, WhisperNativeContextParams @params);

[DllImport(LibraryName)]
public static extern int whisper_lang_max_id();
Expand All @@ -35,10 +40,16 @@ public static unsafe class WhisperNative

[DllImport(LibraryName)]
public static extern whisper_token whisper_token_eot(whisper_context_ptr ctx);


[DllImport(LibraryName)]
public static extern IntPtr whisper_print_system_info();

[DllImport(LibraryName)]
public static extern WhisperNativeParams whisper_full_default_params(WhisperSamplingStrategy strategy);


[DllImport(LibraryName)]
public static extern WhisperNativeContextParams whisper_context_default_params();

[DllImport(LibraryName)]
public static extern int whisper_full(whisper_context_ptr ctx, WhisperNativeParams param,
float* samples, int n_samples);
Expand Down
16 changes: 14 additions & 2 deletions Runtime/Native/WhisperNativeParams.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using System;
using System.Runtime.InteropServices;
// ReSharper disable InconsistentNaming
// ReSharper disable FieldCanBeMadeReadOnly.Local
Expand All @@ -8,7 +9,7 @@
using whisper_context_ptr = System.IntPtr;
using whisper_state_ptr = System.IntPtr;
using whisper_token = System.Int32;
using System;


namespace Whisper.Native
{
Expand Down Expand Up @@ -49,7 +50,18 @@ public struct WhisperNativeTokenData

public float vlen; // voice length of the token
}


/// <summary>
/// This is direct copy of C++ struct.
/// Do not change or add any fields without changing it in whisper.cpp.
/// Do not change it in runtime directly, use <see cref="WhisperContextParams"/>.
/// </summary>
[StructLayout(LayoutKind.Sequential)]
public struct WhisperNativeContextParams
{
[MarshalAs(UnmanagedType.U1)] bool use_gpu;
};

/// <summary>
/// This is direct copy of C++ struct.
/// Do not change or add any fields without changing it in whisper.cpp.
Expand Down
32 changes: 32 additions & 0 deletions Runtime/WhisperParams.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,38 @@

namespace Whisper
{
/// <summary>
/// Wrapper of native C++ whisper context params.
/// Used during whisper model initialization.
/// </summary>
public class WhisperContextParams
{
private WhisperNativeContextParams _param;

/// <summary>
/// Native C++ struct parameters.
/// Do not change it in runtime directly, use setters.
/// </summary>
public WhisperNativeContextParams NativeParams => _param;

private WhisperContextParams(WhisperNativeContextParams param)
{
_param = param;
}

/// <summary>
/// Create a new default Whisper Context parameters.
/// </summary>
public static WhisperContextParams GetDefaultParams()
{
LogUtils.Verbose($"Requesting default Whisper Context params...");
var nativeParams = WhisperNative.whisper_context_default_params();
LogUtils.Verbose("Default Whisper Context params generated!");

return new WhisperContextParams(nativeParams);
}
}

/// <summary>
/// Wrapper of native C++ whisper parameters.
/// Use it to safely change inference parameters.
Expand Down
Loading

0 comments on commit d9b53a7

Please sign in to comment.