Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 94 additions & 49 deletions Sources/TerraFoundationModels/TerraTracedSession.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,28 @@ import OpenTelemetryApi

@available(macOS 26.0, iOS 26.0, *)
public final class TerraTracedSession: @unchecked Sendable {
public enum SessionConcurrencyError: Error, Sendable, Equatable {
case concurrentOperationNotAllowed
}

private actor RequestGate {
private var inFlight = false

func enter() throws {
if inFlight {
throw SessionConcurrencyError.concurrentOperationNotAllowed
}
inFlight = true
}

func leave() {
inFlight = false
}
}

private let session: LanguageModelSession
public let modelIdentifier: String
private let requestGate = RequestGate()

public init(
model: SystemLanguageModel = .default,
Expand All @@ -23,18 +43,21 @@ public final class TerraTracedSession: @unchecked Sendable {

/// Respond to a prompt with auto-tracing.
public func respond(to prompt: String, promptCapture: Terra.CaptureIntent = .default) async throws -> String {
let request = Terra.InferenceRequest(
model: modelIdentifier,
prompt: prompt,
promptCapture: promptCapture
)
return try await Terra.withInferenceSpan(request) { scope in
scope.setAttributes([
Terra.Keys.Terra.runtime: .string("foundation_models"),
Terra.Keys.Terra.autoInstrumented: .bool(true)
])
let response = try await session.respond(to: prompt)
return response.content
try await withExclusiveSessionAccess {
let request = Terra.InferenceRequest(
model: modelIdentifier,
prompt: prompt,
promptCapture: promptCapture
)
let output = try await Terra.withInferenceSpan(request) { scope in
scope.setAttributes([
Terra.Keys.Terra.runtime: .string("foundation_models"),
Terra.Keys.Terra.autoInstrumented: .bool(true)
])
let response = try await session.respond(to: prompt)
return response.content
}
return output
}
}

Expand All @@ -44,18 +67,21 @@ public final class TerraTracedSession: @unchecked Sendable {
generating type: T.Type,
promptCapture: Terra.CaptureIntent = .default
) async throws -> T {
let request = Terra.InferenceRequest(
model: modelIdentifier,
prompt: prompt,
promptCapture: promptCapture
)
return try await Terra.withInferenceSpan(request) { scope in
scope.setAttributes([
Terra.Keys.Terra.runtime: .string("foundation_models"),
Terra.Keys.Terra.autoInstrumented: .bool(true),
"terra.foundation_models.response_type": .string(String(describing: T.self))
])
return try await session.respond(to: prompt, generating: type).content
try await withExclusiveSessionAccess {
let request = Terra.InferenceRequest(
model: modelIdentifier,
prompt: prompt,
promptCapture: promptCapture
)
let output: T = try await Terra.withInferenceSpan(request) { scope in
scope.setAttributes([
Terra.Keys.Terra.runtime: .string("foundation_models"),
Terra.Keys.Terra.autoInstrumented: .bool(true),
"terra.foundation_models.response_type": .string(String(describing: T.self))
])
return try await session.respond(to: prompt, generating: type).content
}
return output
}
}

Expand All @@ -66,26 +92,33 @@ public final class TerraTracedSession: @unchecked Sendable {

return AsyncThrowingStream { continuation in
let task = Task { [weak self] in
let request = Terra.InferenceRequest(
model: modelIdentifier,
prompt: prompt,
promptCapture: promptCapture,
stream: true
)
guard let self else {
continuation.finish(throwing: CancellationError())
return
Comment on lines +95 to +97

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Retain session during stream task startup

The new guard let self causes streamResponse to fail with CancellationError if the TerraTracedSession instance is released before the Task starts, even though modelIdentifier and session are already captured and sufficient to run the stream. This introduces a regression for call patterns that don't keep an extra strong reference (for example, consuming a stream returned from a temporary session), leading to nondeterministic early cancellation.

Useful? React with 👍 / 👎.

}

do {
try await Terra.withStreamingInferenceSpan(request) { streamScope in
streamScope.setAttributes([
Terra.Keys.Terra.runtime: .string("foundation_models"),
Terra.Keys.Terra.autoInstrumented: .bool(true)
])
let stream = session.streamResponse(to: prompt)
for try await partial in stream {
try Task.checkCancellation()
streamScope.recordChunk()
if let explicitCount = self?.explicitOutputTokenCount(from: partial) {
streamScope.recordOutputTokenCount(explicitCount)
try await self.withExclusiveSessionAccess {
let request = Terra.InferenceRequest(
model: modelIdentifier,
prompt: prompt,
promptCapture: promptCapture,
stream: true
)
try await Terra.withStreamingInferenceSpan(request) { streamScope in
streamScope.setAttributes([
Terra.Keys.Terra.runtime: .string("foundation_models"),
Terra.Keys.Terra.autoInstrumented: .bool(true)
])
let stream = session.streamResponse(to: prompt)
for try await partial in stream {
try Task.checkCancellation()
streamScope.recordChunk()
if let explicitCount = self.explicitOutputTokenCount(from: partial) {
streamScope.recordOutputTokenCount(explicitCount)
}
continuation.yield(partial.content)
}
continuation.yield(partial.content)
}
}
continuation.finish()
Expand All @@ -106,22 +139,34 @@ public final class TerraTracedSession: @unchecked Sendable {
"tokensGenerated",
]

/// Tracks whether we've already probed for a token count field and found none.
private var tokenCountFieldChecked = false

private func explicitOutputTokenCount(from partial: Any) -> Int? {
// After first nil result, skip Mirror reflection entirely
if tokenCountFieldChecked { return nil }

for child in Mirror(reflecting: partial).children {
guard let label = child.label, Self.supportedTokenCountNames.contains(label) else { continue }
if let intValue = child.value as? Int, intValue >= 0 {
return intValue
}
}
tokenCountFieldChecked = true
return nil
}

private func withExclusiveSessionAccess<T>(_ operation: () async throws -> T) async throws -> T {
try await requestGate.enter()
do {
let value = try await operation()
await requestGate.leave()
return value
} catch {
await requestGate.leave()
throw error
}
}

func _holdExclusiveAccessForTesting(nanoseconds: UInt64) async throws {
_ = try await withExclusiveSessionAccess {
try await Task.sleep(nanoseconds: nanoseconds)
return ()
}
}
}

#else
Expand Down
81 changes: 49 additions & 32 deletions Sources/TerraTraceKit/OTLPHTTPServer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,17 @@ public final class OTLPHTTPServer: @unchecked Sendable {
private static let maxActiveConnections = 64

private let queue = DispatchQueue(label: "terra.trace.otlp.httpserver")
private let queueKey = DispatchSpecificKey<UInt8>()
private let queueIdentity: UInt8 = 1
private var listener: NWListener?
private var activeConnections: [ObjectIdentifier: NWConnection] = [:]
private var readTimeoutTimers: [ObjectIdentifier: DispatchSourceTimer] = [:]
private var decodeTasks: [ObjectIdentifier: Task<Void, Never>] = [:]

public var port: UInt16 {
listener?.port?.rawValue ?? configuredPort
withQueueSync {
listener?.port?.rawValue ?? configuredPort
}
}

public init(
Expand All @@ -66,55 +70,52 @@ public final class OTLPHTTPServer: @unchecked Sendable {
self.traceStore = traceStore
self.limits = limits
self.onSpans = onSpans
queue.setSpecific(key: queueKey, value: queueIdentity)
}

public func start() throws {
guard listener == nil else { return }

let parameters = NWParameters.tcp
let listener: NWListener
if configuredPort == 0 {
listener = try NWListener(using: parameters)
} else if let port = NWEndpoint.Port(rawValue: configuredPort) {
if shouldBindToHost(host) {
parameters.requiredLocalEndpoint = .hostPort(host: NWEndpoint.Host(host), port: port)
try withQueueSync {
guard listener == nil else { return }

let parameters = NWParameters.tcp
let listener: NWListener
if configuredPort == 0 {
listener = try NWListener(using: parameters)
} else if let port = NWEndpoint.Port(rawValue: configuredPort) {
if shouldBindToHost(host) {
parameters.requiredLocalEndpoint = .hostPort(host: NWEndpoint.Host(host), port: port)
listener = try NWListener(using: parameters)
} else {
listener = try NWListener(using: parameters, on: port)
}
} else {
listener = try NWListener(using: parameters, on: port)
throw NSError(domain: "OTLPHTTPServer", code: 1, userInfo: [NSLocalizedDescriptionKey: "Invalid port"])
}
} else {
throw NSError(domain: "OTLPHTTPServer", code: 1, userInfo: [NSLocalizedDescriptionKey: "Invalid port"])
}

listener.stateUpdateHandler = { [weak self] (state: NWListener.State) in
if case .failed = state {
self?.stop()
listener.stateUpdateHandler = { [weak self] (state: NWListener.State) in
if case .failed = state {
self?.stop()
}
}
}

listener.newConnectionHandler = { [weak self] connection in
self?.handle(connection)
}
listener.newConnectionHandler = { [weak self] connection in
self?.handle(connection)
}

self.listener = listener
listener.start(queue: queue)
self.listener = listener
listener.start(queue: queue)
}
}

public func stop() {
queue.async {
self.listener?.cancel()
self.listener = nil
for id in Array(self.activeConnections.keys) {
self.cleanupConnection(id: id)
}
self.stopLocked()
}
}

deinit {
listener?.cancel()
listener = nil
for id in Array(activeConnections.keys) {
cleanupConnection(id: id)
withQueueSync {
self.stopLocked()
}
}

Expand Down Expand Up @@ -410,6 +411,22 @@ public final class OTLPHTTPServer: @unchecked Sendable {
return !lowered.isEmpty && lowered != "0.0.0.0" && lowered != "::"
}

private func stopLocked() {
listener?.cancel()
listener = nil
for id in Array(activeConnections.keys) {
cleanupConnection(id: id)
}
}

@discardableResult
private func withQueueSync<T>(_ body: () throws -> T) rethrows -> T {
if DispatchQueue.getSpecific(key: queueKey) == queueIdentity {
return try body()
}
return try queue.sync(execute: body)
}

private func parseRequestHead(_ data: Data) -> Result<HTTPRequestHead, HTTPParseError> {
guard let headerString = String(data: data, encoding: .utf8) else {
return .failure(.badRequest("Invalid header encoding"))
Expand Down
26 changes: 14 additions & 12 deletions Sources/TerraTraceKit/TraceFileReader.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,23 +25,25 @@ public struct TraceFileReader {
}

do {
let attributes = try fileManager.attributesOfItem(atPath: url.path)
if let sizeValue = attributes[.size] as? NSNumber {
let size = sizeValue.intValue
if size > maxFileSizeBytes {
throw TraceFileError.fileTooLarge(url, actualBytes: size, maxBytes: maxFileSizeBytes)
}
let handle = try FileHandle(forReadingFrom: url)
defer { try? handle.close() }

let initialSize = try handle.seekToEnd()
if initialSize > UInt64(maxFileSizeBytes) {

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Validate max size before unsigned/file-read arithmetic

maxFileSizeBytes is a public initializer input, but this code now performs UInt64(maxFileSizeBytes) and maxFileSizeBytes + 1 without validation; values like -1 or Int.max will trap at runtime instead of throwing a TraceFileError. This turns malformed configuration into a process crash in production paths that instantiate TraceFileReader with external limits.

Useful? React with 👍 / 👎.

let actualBytes = Int(initialSize > UInt64(Int.max) ? UInt64(Int.max) : initialSize)
throw TraceFileError.fileTooLarge(url, actualBytes: actualBytes, maxBytes: maxFileSizeBytes)
}
try handle.seek(toOffset: 0)

let data = try handle.read(upToCount: maxFileSizeBytes + 1) ?? Data()
if data.count > maxFileSizeBytes {
throw TraceFileError.fileTooLarge(url, actualBytes: data.count, maxBytes: maxFileSizeBytes)
}
return data
} catch let fileError as TraceFileError {
throw fileError
} catch {
throw TraceFileError.readFailed(url)
}

do {
return try Data(contentsOf: url)
} catch {
throw TraceFileError.readFailed(url)
}
}
}
26 changes: 25 additions & 1 deletion Tests/TerraFoundationModelsTests/TerraTracedSessionTests.swift
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import Testing

#if canImport(FoundationModels)
import TerraFoundationModels
@testable import TerraFoundationModels
import TerraCore

@available(macOS 26.0, iOS 26.0, *)
Expand All @@ -18,6 +18,30 @@ func tracedSessionInitializesWithCustomIdentifier() {
#expect(session.modelIdentifier == "apple/custom-model")
}

@available(macOS 26.0, iOS 26.0, *)
@Test("TerraTracedSession rejects concurrent in-flight operations")
func tracedSessionRejectsConcurrentOperations() async throws {
let session = TerraTracedSession()

let holdingTask = Task {
try await session._holdExclusiveAccessForTesting(nanoseconds: 300_000_000)
}
defer { holdingTask.cancel() }

try await Task.sleep(nanoseconds: 50_000_000)

do {
try await session._holdExclusiveAccessForTesting(nanoseconds: 10_000_000)
Issue.record("Expected concurrentOperationNotAllowed error")
} catch let error as TerraTracedSession.SessionConcurrencyError {
#expect(error == .concurrentOperationNotAllowed)
} catch {
Issue.record("Unexpected error type: \(error)")
}

try await holdingTask.value
}

#else

// FoundationModels is not available on this platform or SDK.
Expand Down
Loading