Swift 6 Complete Concurrency Checking #85
base: main
Changes from all commits: 258bca8, aeeadd8, f358d3e, ec16833, a8977fa, 6618578, 451ea40, ae6eacd
Changes to the LLMFogSession extension:

```diff
@@ -12,14 +12,7 @@ import SpeziChat
 
 extension LLMFogSession {
-    private static let modelNotFoundRegex: Regex = {
-        guard let regex = try? Regex("model '([\\w:]+)' not found, try pulling it first") else {
-            preconditionFailure("SpeziLLMFog: Error Regex could not be parsed")
-        }
-
-        return regex
-    }()
+    private static let modelNotFoundRegex = "model '([\\w:]+)' not found, try pulling it first"
 
     /// Based on the input prompt, generate the output via some OpenAI API, e.g., Ollama.
     ///
@@ -61,7 +54,7 @@ extension LLMFogSession {
             }
         } catch let error as APIErrorResponse {
             // Sadly, there's no better way to check the error messages as there aren't any Ollama error codes as with the OpenAI API
-            if error.error.message.contains(Self.modelNotFoundRegex) {
+            if error.error.message.range(of: Self.modelNotFoundRegex, options: .regularExpression) != nil {
                 Self.logger.error("SpeziLLMFog: LLM model type could not be accessed on fog node - \(error.error.message)")
                 await finishGenerationWithError(LLMFogError.modelAccessError(error), on: continuation)
             } else if error.error.code == "401" || error.error.code == "403" {
```

Review comment on the `range(of:options:)` change: "Nice simplification!"
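As a quick illustration of the new string-based check, here is a standalone sketch; only the pattern comes from the diff, while the sample error message is invented for demonstration:

```swift
import Foundation

// Standalone sketch of the string-based regex check from the diff above.
// The sample message is an assumption; only the pattern is taken from the change.
let modelNotFoundRegex = "model '([\\w:]+)' not found, try pulling it first"
let sampleMessage = "model 'llama3:8b' not found, try pulling it first"

// With `.regularExpression`, `range(of:)` interprets the pattern as a regex
// and returns a non-nil range whenever the message matches it.
if sampleMessage.range(of: modelNotFoundRegex, options: .regularExpression) != nil {
    print("Detected a model-not-found error")   // prints for the sample message
}
```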
Changes to LLMLocalDownloadManager:

```diff
@@ -23,7 +23,7 @@ import SpeziViews
 /// is of type ``LLMLocalDownloadManager/DownloadState``, containing states such as ``LLMLocalDownloadManager/DownloadState/downloading(progress:)``
 /// which includes the progress of the download or ``LLMLocalDownloadManager/DownloadState/downloaded(storageUrl:)`` which indicates that the download has finished.
 @Observable
-public final class LLMLocalDownloadManager: NSObject {
+public final class LLMLocalDownloadManager: NSObject, @unchecked Sendable {
     /// An enum containing all possible states of the ``LLMLocalDownloadManager``.
     public enum DownloadState: Equatable {
         case idle
```
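For context on what the added conformance asserts, here is a standalone sketch, not code from the PR: `@unchecked Sendable` opts the class out of the compiler's data-race checks and makes the author responsible for thread safety, which is exactly what the review thread further down debates. Both classes below are invented for illustration.

```swift
import Foundation

// A checked conformance compiles only when the compiler can prove safety,
// e.g. a final class whose stored properties are all immutable and Sendable.
final class CheckedConfig: Sendable {
    let modelID: String
    init(modelID: String) { self.modelID = modelID }
    // var state: Int = 0   // error: stored property of a 'Sendable'-conforming class is mutable
}

// `@unchecked Sendable` silences those checks: the author asserts that access
// to mutable state is synchronized by other means (locks, a global actor, ...).
final class UncheckedManager: @unchecked Sendable {
    var state: Int = 0   // no longer verified by the compiler
}
```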
```diff
@@ -79,49 +79,47 @@ public final class LLMLocalDownloadManager: NSObject {
     }
 
     /// Starts a `URLSessionDownloadTask` to download the specified model.
+    @MainActor
     public func startDownload() async {
         if modelExist {
-            Task { @MainActor in
-                self.state = .downloaded
-            }
+            state = .downloaded
             return
         }
 
         await cancelDownload()
         downloadTask = Task(priority: .userInitiated) {
             do {
                 try await downloadWithHub()
-                await MainActor.run {
-                    self.state = .downloaded
-                }
+                state = .downloaded
             } catch {
-                await MainActor.run {
-                    self.state = .error(
-                        AnyLocalizedError(
-                            error: error,
-                            defaultErrorDescription: LocalizedStringResource("LLM_DOWNLOAD_FAILED_ERROR", bundle: .atURL(from: .module))
+                state = .error(
+                    AnyLocalizedError(
+                        error: error,
+                        defaultErrorDescription: LocalizedStringResource(
+                            "LLM_DOWNLOAD_FAILED_ERROR",
+                            bundle: .atURL(from: .module)
                         )
                     )
-                }
+                )
             }
         }
     }
 
     /// Cancels the download of a specified model via a `URLSessionDownloadTask`.
+    @MainActor
     public func cancelDownload() async {
         downloadTask?.cancel()
-        await MainActor.run {
-            self.state = .idle
-        }
+        state = .idle
     }
 
-    @MainActor
     private func downloadWithHub() async throws {
        let repo = Hub.Repo(id: model.hubID)
        let modelFiles = ["*.safetensors", "config.json"]
 
        try await HubApi.shared.snapshot(from: repo, matching: modelFiles) { progress in
```
Review thread on this progress closure:

"This is forcing me to make the whole class `@unchecked Sendable`."

"Yeah, we don't want to mark the manager `@unchecked Sendable`. Did you try something like (rough sketch)?"

```swift
    try await HubApi.shared.snapshot(from: repo, matching: modelFiles) { progress in
        Task { @MainActor [mutate = self.mutate] in
            mutate(progress)
        }
    }
}

@MainActor private func mutate(progress: Progress) {
    self.state = .downloading(progress: progress)
}
```

"That could work; I agree that marking the complete class as `Sendable` without any manual locking implementation might not be great, so I would aim to avoid it if we can, or mark everything with `@MainActor` (which is also not great), which could basically also guarantee that?"

"As you mentioned, marking everything as `@MainActor` …"
The rest of the hunk applies that main-actor hop directly inside the closure:

```diff
-            self.state = .downloading(progress: progress)
+            Task { @MainActor in
+                self.state = .downloading(progress: progress)
+            }
         }
     }
 }
```
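To make the pattern from the thread above concrete, here is a hedged, standalone sketch; the type, property, and method names are invented and are not the PR's code. It shows the alternative the reviewers discuss, main-actor isolation for the stateful class, combined with the same explicit hop the PR performs inside the Hub progress callback:

```swift
import Foundation

// Invented example: a main-actor-isolated class whose state is mutated from
// work running off the main actor, mirroring the download manager's situation.
@MainActor
final class ProgressReporter {
    var progressText = "idle"   // main-actor-isolated state (like `state` in the manager)

    nonisolated func simulatedDownload() async {
        for step in 1...3 {
            try? await Task.sleep(nanoseconds: 100_000_000)
            let percent = step * 100 / 3
            // This loop runs off the main actor, so hop explicitly before touching
            // main-actor-isolated state - the same move the PR makes inside the
            // HubApi progress closure.
            Task { @MainActor in
                self.progressText = "downloading \(percent)%"
            }
        }
    }
}
```

Because a `@MainActor` class is implicitly `Sendable`, this variant would not need the `@unchecked Sendable` annotation, at the cost of funnelling all members through the main actor.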
Review comment: "We should try to unify this across all packages; some of them use `.git` at the end, some of them don't. I would suggest adding `.git` for all of them for now."
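Assuming this comment refers to Swift Package Manager dependency URLs (the file it is attached to is not visible in this extract), the unified style would look like the hypothetical `Package.swift` excerpt below; package names and version numbers are illustrative only:

```swift
// swift-tools-version:5.9
// Hypothetical manifest sketch; only the ".git" suffix convention is the point here.
import PackageDescription

let package = Package(
    name: "Example",
    dependencies: [
        // Suggested unified style: always append ".git" to the repository URL.
        .package(url: "https://github.com/StanfordSpezi/SpeziChat.git", from: "0.1.0"),
        // URLs without the suffix also resolve, but mixing both styles is inconsistent.
        .package(url: "https://github.com/StanfordSpezi/SpeziViews", from: "1.0.0")
    ],
    targets: [
        .target(name: "Example")
    ]
)
```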