Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Swift 6 Complete Concurrency Checking #85

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions Package.swift
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// swift-tools-version:5.9
// swift-tools-version:6.0

//
// This source file is part of the Stanford Spezi open source project
Expand All @@ -10,7 +10,6 @@

import PackageDescription


let package = Package(
name: "SpeziLLM",
defaultLocalization: "en",
Expand All @@ -35,7 +34,7 @@ let package = Package(
.package(url: "https://github.com/StanfordSpezi/SpeziFoundation", from: "2.0.0"),
.package(url: "https://github.com/StanfordSpezi/SpeziStorage", from: "1.0.2"),
.package(url: "https://github.com/StanfordSpezi/SpeziOnboarding", from: "1.1.1"),
.package(url: "https://github.com/StanfordSpezi/SpeziChat", .upToNextMinor(from: "0.2.1")),
.package(url: "https://github.com/StanfordSpezi/SpeziChat", .upToNextMinor(from: "0.2.2")),
.package(url: "https://github.com/StanfordSpezi/SpeziViews", from: "1.3.1")
],
targets: [
Expand Down
2 changes: 1 addition & 1 deletion Sources/SpeziLLM/LLMPlatformState.swift
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
/// Describes the current state of the ``LLMPlatform`` which is responsible for sending ``LLMSchema``s to execution via ``LLMSession``s.
///
/// The ``LLMPlatformState`` is quite minimal with only ``LLMPlatformState/idle`` and ``LLMPlatformState/processing`` states.
public enum LLMPlatformState {
public enum LLMPlatformState: Sendable {
/// Indicates that the ``LLMPlatform`` is currently idle and doesn't execute any ``LLMSession``s.
case idle
/// Indicates that the ``LLMPlatform`` is currently processing and executing ``LLMSession``s.
Expand Down
2 changes: 1 addition & 1 deletion Sources/SpeziLLM/LLMSessionProvider.swift
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public struct _LLMSessionProvider<Schema: LLMSchema>: DynamicProperty { // s
}

/// Creates a `Binding` to the ``LLMSession`` that one can pass around. Useful for passing the ``LLMSession`` as a `Binding` to the ``LLMChatView``.
public var projectedValue: Binding<Schema.Platform.Session> {
@MainActor public var projectedValue: Binding<Schema.Platform.Session> {
jdisho marked this conversation as resolved.
Show resolved Hide resolved
Binding {
wrappedValue
} set: {
Expand Down
6 changes: 3 additions & 3 deletions Sources/SpeziLLM/Models/LLMContextEntity.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ import Foundation
/// A ``LLMContextEntity`` can be thought of as a single message entity within a ``LLMContext``.
/// It consists of a ``LLMContextEntity/Role``, a unique identifier, a timestamp in the form of a `Date` as well as a `String`-based ``LLMContextEntity/content`` property which can contain Markdown-formatted text.
/// Furthermore, the ``LLMContextEntity/complete`` flag indicates if the current state of the ``LLMContextEntity`` is final and the content will not be updated anymore.
public struct LLMContextEntity: Codable, Equatable, Hashable, Identifiable {
public struct LLMContextEntity: Codable, Equatable, Hashable, Identifiable, Sendable {
/// Represents a tool call by the LLM, including its parameters
public struct ToolCall: Codable, Equatable, Hashable {
public struct ToolCall: Codable, Equatable, Hashable, Sendable {
/// The ID of the function call, uniquely identifying the specific function call and matching the response to it.
public let id: String
/// The name of the function call.
Expand All @@ -39,7 +39,7 @@ public struct LLMContextEntity: Codable, Equatable, Hashable, Identifiable {
}

/// Indicates which ``LLMContextEntity/Role`` is associated with a ``LLMContextEntity``.
public enum Role: Codable, Equatable, Hashable {
public enum Role: Codable, Equatable, Hashable, Sendable {
case user
case assistant(toolCalls: [ToolCall] = [])
case system
Expand Down
2 changes: 1 addition & 1 deletion Sources/SpeziLLM/Models/LLMState.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import Foundation
/// Describes possible states that the ``LLMSession`` can be in.
///
/// Based on the ``LLMState``, `SpeziLLM` performs proper actions on the model as well as state management.
public enum LLMState: CustomStringConvertible, Equatable {
public enum LLMState: CustomStringConvertible, Equatable, Sendable {
/// The Spezi ``LLMSession`` is allocated, but the underlying model has not yet been initialized.
case uninitialized
/// The Spezi ``LLMSession`` is in the process of being initialized.
Expand Down
11 changes: 2 additions & 9 deletions Sources/SpeziLLMFog/LLMFogSession+Generation.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,7 @@ import SpeziChat


extension LLMFogSession {
private static let modelNotFoundRegex: Regex = {
guard let regex = try? Regex("model '([\\w:]+)' not found, try pulling it first") else {
preconditionFailure("SpeziLLMFog: Error Regex could not be parsed")
}

return regex
}()

private static let modelNotFoundRegex = "model '([\\w:]+)' not found, try pulling it first"

/// Based on the input prompt, generate the output via some OpenAI API, e.g., Ollama.
///
Expand Down Expand Up @@ -61,7 +54,7 @@ extension LLMFogSession {
}
} catch let error as APIErrorResponse {
// Sadly, there's no better way to check the error messages as there aren't any Ollama error codes as with the OpenAI API
if error.error.message.contains(Self.modelNotFoundRegex) {
if error.error.message.range(of: Self.modelNotFoundRegex, options: .regularExpression) != nil {
Self.logger.error("SpeziLLMFog: LLM model type could not be accessed on fog node - \(error.error.message)")
await finishGenerationWithError(LLMFogError.modelAccessError(error), on: continuation)
} else if error.error.code == "401" || error.error.code == "403" {
Expand Down
36 changes: 17 additions & 19 deletions Sources/SpeziLLMLocalDownload/LLMLocalDownloadManager.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import SpeziViews
/// is of type ``LLMLocalDownloadManager/DownloadState``, containing states such as ``LLMLocalDownloadManager/DownloadState/downloading(progress:)``
/// which includes the progress of the download or ``LLMLocalDownloadManager/DownloadState/downloaded(storageUrl:)`` which indicates that the download has finished.
@Observable
public final class LLMLocalDownloadManager: NSObject {
public final class LLMLocalDownloadManager: NSObject, @unchecked Sendable {
/// An enum containing all possible states of the ``LLMLocalDownloadManager``.
public enum DownloadState: Equatable {
case idle
Expand Down Expand Up @@ -79,49 +79,47 @@ public final class LLMLocalDownloadManager: NSObject {
}

/// Starts a `URLSessionDownloadTask` to download the specified model.
@MainActor
public func startDownload() async {
if modelExist {
Task { @MainActor in
self.state = .downloaded
}
state = .downloaded
return
}

await cancelDownload()
downloadTask = Task(priority: .userInitiated) {
do {
try await downloadWithHub()
await MainActor.run {
self.state = .downloaded
}
state = .downloaded
} catch {
await MainActor.run {
self.state = .error(
AnyLocalizedError(
error: error,
defaultErrorDescription: LocalizedStringResource("LLM_DOWNLOAD_FAILED_ERROR", bundle: .atURL(from: .module))
state = .error(
AnyLocalizedError(
error: error,
defaultErrorDescription: LocalizedStringResource(
"LLM_DOWNLOAD_FAILED_ERROR",
bundle: .atURL(from: .module)
)
)
}
)
}
}
}

/// Cancels the download of a specified model via a `URLSessionDownloadTask`.
@MainActor
public func cancelDownload() async {
downloadTask?.cancel()
await MainActor.run {
self.state = .idle
}
state = .idle
}

@MainActor

private func downloadWithHub() async throws {
let repo = Hub.Repo(id: model.hubID)
let modelFiles = ["*.safetensors", "config.json"]

try await HubApi.shared.snapshot(from: repo, matching: modelFiles) { progress in
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is forcing me to make the whole class @unchecked Sendable, because we can't pass progress safely without potentially causing data races. Do you have another idea?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah we don't want to mark the manager Sendable cause it isn't.
The issue is the double capture of self, once within the snapshot() callback and once within the Task.

Did you try something like (rough sketch)

try await HubApi.shared.snapshot(from: repo, matching: modelFiles) { progress in
        Task { @MainActor [mutate = self.mutate] in
            mutate(progress)
        }
    }
}

@MainActor private func mutate(progress: Progress) {
      self.state = .downloading(progress: progress)
}

self.state = .downloading(progress: progress)
Task { @MainActor in
self.state = .downloading(progress: progress)
}
}
}
}
Loading