Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

KTOR-7194 Deferred session fetching for public endpoints #4609

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@ import io.ktor.server.response.*
import io.ktor.server.routing.*
import io.ktor.server.sessions.*
import io.ktor.server.testing.*
import kotlinx.serialization.*
import kotlin.test.*
import kotlinx.serialization.Serializable
import kotlin.test.Test
import kotlin.test.assertEquals

class SessionAuthTest {
open class SessionAuthTest {
@Test
fun testSessionOnly() = testApplication {
install(Sessions) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/*
* Copyright 2014-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/

package io.ktor.tests.auth

import io.ktor.client.request.*
import io.ktor.http.*
import io.ktor.server.application.*
import io.ktor.server.auth.*
import io.ktor.server.response.*
import io.ktor.server.routing.*
import io.ktor.server.sessions.*
import io.ktor.server.sessions.serialization.*
import io.ktor.server.testing.*
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.Json
import kotlin.test.*

class SessionAuthDeferredTest : SessionAuthTest() {

@BeforeTest
fun setProperty() {
System.setProperty("io.ktor.server.sessions.deferred", "true")
}

@AfterTest
fun clearProperty() {
System.clearProperty("io.ktor.server.sessions.deferred")
}

@Test
fun sessionIgnoredForNonPublicEndpoints() = testApplication {
val brokenStorage = object : SessionStorage {
override suspend fun write(id: String, value: String) = Unit
override suspend fun invalidate(id: String) = error("invalidate called")
override suspend fun read(id: String): String = error("read called")
}
application {
install(Sessions) {
cookie<MySession>("S", storage = brokenStorage) {
serializer = KotlinxSessionSerializer(Json.Default)
}
}
install(Authentication.Companion) {
session<MySession> {
validate { it }
}
}
routing {
authenticate {
get("/authenticated") {
call.respondText("Secret info")
}
}
post("/session") {
call.sessions.set(MySession(1))
call.respondText("OK")
}
get("/public") {
call.respondText("Public info")
}
}
}
val withCookie: HttpRequestBuilder.() -> Unit = {
header("Cookie", "S=${defaultSessionSerializer<MySession>().serialize(MySession(1))}")
}

assertEquals(HttpStatusCode.Companion.OK, client.post("/session").status)
assertEquals(HttpStatusCode.Companion.OK, client.get("/public", withCookie).status)
assertFailsWith<IllegalStateException> {
client.get("/authenticated", withCookie).status
}
}

@Serializable
data class MySession(val id: Int)
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,19 @@ public interface CurrentSession {
public fun findName(type: KClass<*>): String
}

/**
* Extends [CurrentSession] with a call to include session data in the server response.
*/
internal interface StatefulSession : CurrentSession {

/**
* Iterates over session data items and writes them to the application call.
* The session cannot be modified after this is called.
* This is called after the session data is sent to the response.
*/
suspend fun sendSessionData(call: ApplicationCall, onEach: (String) -> Unit = {})
}

/**
* Sets a session instance with the type [T].
* @throws IllegalStateException if no session provider is registered for the type [T]
Expand Down Expand Up @@ -99,11 +112,15 @@ public inline fun <reified T : Any> CurrentSession.getOrSet(name: String = findN

internal data class SessionData(
val providerData: Map<String, SessionProviderData<*>>
) : CurrentSession {
) : StatefulSession {

private var committed = false

internal fun commit() {
override suspend fun sendSessionData(call: ApplicationCall, onEach: (String) -> Unit) {
providerData.values.forEach { data ->
onEach(data.provider.name)
data.sendSessionData(call)
}
committed = true
}

Expand Down Expand Up @@ -175,7 +192,7 @@ internal data class SessionProviderData<S : Any>(
val provider: SessionProvider<S>
)

internal val SessionDataKey = AttributeKey<SessionData>("SessionKey")
internal val SessionDataKey = AttributeKey<StatefulSession>("SessionKey")

private fun ApplicationCall.reportMissingSession(): Nothing {
application.plugin(Sessions) // ensure the plugin is installed
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
/*
* Copyright 2014-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/

package io.ktor.server.sessions

import io.ktor.server.application.ApplicationCall

internal const val SESSIONS_DEFERRED_FLAG = "io.ktor.server.sessions.deferred"

internal expect fun isDeferredSessionsEnabled(): Boolean

/**
* Creates a lazy loading session from the given providers.
*/
internal expect fun createDeferredSession(call: ApplicationCall, providers: List<SessionProvider<*>>): StatefulSession
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package io.ktor.server.sessions

import io.ktor.server.application.*
import io.ktor.server.request.*
import io.ktor.server.response.*
import io.ktor.util.*
import io.ktor.util.logging.*

Expand All @@ -27,24 +26,23 @@ internal val LOGGER = KtorSimpleLogger("io.ktor.server.sessions.Sessions")
*/
public val Sessions: RouteScopedPlugin<SessionsConfig> = createRouteScopedPlugin("Sessions", ::SessionsConfig) {
val providers = pluginConfig.providers.toList()
val sessionSupplier: suspend (ApplicationCall, List<SessionProvider<*>>) -> StatefulSession =
if (isDeferredSessionsEnabled()) {
::createDeferredSession
} else {
::createSession
}

application.attributes.put(SessionProvidersKey, providers)

onCall { call ->
// For each call, call each provider and retrieve session data if needed.
// Capture data in the attribute's value
val providerData = providers.associateBy({ it.name }) {
it.receiveSessionData(call)
}

if (providerData.isEmpty()) {
LOGGER.trace("No sessions found for ${call.request.uri}")
if (providers.isEmpty()) {
LOGGER.trace { "No sessions found for ${call.request.uri}" }
} else {
val sessions = providerData.keys.joinToString()
LOGGER.trace("Sessions found for ${call.request.uri}: $sessions")
val sessions = providers.joinToString { it.name }
LOGGER.trace { "Sessions found for ${call.request.uri}: $sessions" }
bjhham marked this conversation as resolved.
Show resolved Hide resolved
}
val sessionData = SessionData(providerData)
call.attributes.put(SessionDataKey, sessionData)
call.attributes.put(SessionDataKey, sessionSupplier(call, providers))
}

// When response is being sent, call each provider to update/remove session data
Expand All @@ -58,11 +56,18 @@ public val Sessions: RouteScopedPlugin<SessionsConfig> = createRouteScopedPlugin
*/
val sessionData = call.attributes.getOrNull(SessionDataKey) ?: return@on

sessionData.providerData.values.forEach { data ->
LOGGER.trace("Sending session data for ${call.request.uri}: ${data.provider.name}")
data.sendSessionData(call)
sessionData.sendSessionData(call) { provider ->
LOGGER.trace { "Sending session data for ${call.request.uri}: $provider" }
}
}
}

sessionData.commit()
private suspend fun createSession(call: ApplicationCall, providers: List<SessionProvider<*>>): StatefulSession {
// For each call, call each provider and retrieve session data if needed.
// Capture data in the attribute's value
val providerData = providers.associateBy({ it.name }) {
it.receiveSessionData(call)
}

return SessionData(providerData)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
/*
* Copyright 2014-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/

package io.ktor.server.sessions

import io.ktor.server.application.ApplicationCall

internal actual fun isDeferredSessionsEnabled(): Boolean = false

internal actual fun createDeferredSession(call: ApplicationCall, providers: List<SessionProvider<*>>): StatefulSession =
TODO("Deferred session retrieval is currently only available for JVM")
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
/*
* Copyright 2014-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/

package io.ktor.server.sessions

internal actual fun isDeferredSessionsEnabled(): Boolean =
System.getProperty(SESSIONS_DEFERRED_FLAG)?.toBoolean() == true
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,18 @@
package io.ktor.server.sessions

import io.ktor.http.*
import io.ktor.server.sessions.serialization.*
import io.ktor.util.*
import kotlinx.serialization.*
import kotlinx.serialization.json.*
import java.lang.reflect.*
import java.math.*
import java.lang.reflect.ParameterizedType
import java.lang.reflect.Type
import java.math.BigDecimal
import java.math.BigInteger
import java.util.*
import java.util.concurrent.*
import java.util.concurrent.ConcurrentHashMap
import kotlin.reflect.*
import kotlin.reflect.full.*
import kotlin.reflect.jvm.*
import kotlin.reflect.full.memberProperties
import kotlin.reflect.full.superclasses
import kotlin.reflect.jvm.javaType
import kotlin.reflect.jvm.jvmErasure

private const val TYPE_TOKEN_PARAMETER_NAME: String = "\$type"

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
* Copyright 2014-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/

package io.ktor.server.sessions

import io.ktor.server.application.ApplicationCall
import kotlinx.coroutines.Deferred
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.runBlocking
import kotlin.coroutines.CoroutineContext
import kotlin.reflect.KClass

/**
* An implementation of [StatefulSession] that lazily references session providers to
* avoid unnecessary calls to session storage.
* All access to the deferred providers is done through blocking calls.
*/
internal class BlockingDeferredSessionData(
val callContext: CoroutineContext,
val providerData: Map<String, Deferred<SessionProviderData<*>>>,
) : StatefulSession {

private var committed = false

@OptIn(ExperimentalCoroutinesApi::class)
override suspend fun sendSessionData(call: ApplicationCall, onEach: (String) -> Unit) {
for (deferredProvider in providerData.values) {
// skip non-completed providers because they were not modified
if (!deferredProvider.isCompleted) continue
val data = deferredProvider.getCompleted()
onEach(data.provider.name)
data.sendSessionData(call)
}
committed = true
}

override fun findName(type: KClass<*>): String {
val entry = providerData.values.map {
it.awaitBlocking()
}.firstOrNull {
it.provider.type == type
} ?: throw IllegalArgumentException("Session data for type `$type` was not registered")

return entry.provider.name
}

override fun set(name: String, value: Any?) {
if (committed) {
throw TooLateSessionSetException()
}
val providerData =
providerData[name] ?: throw IllegalStateException("Session data for `$name` was not registered")
bjhham marked this conversation as resolved.
Show resolved Hide resolved
setTyped(providerData.awaitBlocking(), value)
}

@Suppress("UNCHECKED_CAST")
private fun <S : Any> setTyped(data: SessionProviderData<S>, value: Any?) {
if (value != null) {
data.provider.tracker.validate(value as S)
}
data.newValue = value as S
}

override fun get(name: String): Any? {
val providerDataDeferred =
providerData[name] ?: throw IllegalStateException("Session data for `$name` was not registered")
val providerData = providerDataDeferred.awaitBlocking()
return providerData.newValue ?: providerData.oldValue
}

override fun clear(name: String) {
val providerDataDeferred =
providerData[name] ?: throw IllegalStateException("Session data for `$name` was not registered")
val providerData = providerDataDeferred.awaitBlocking()
providerData.oldValue = null
providerData.newValue = null
}

private fun Deferred<SessionProviderData<*>>.awaitBlocking() =
runBlocking(callContext) { await() }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
/*
* Copyright 2014-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/

package io.ktor.server.sessions

import io.ktor.server.application.ApplicationCall
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.CoroutineStart
import kotlinx.coroutines.async

internal actual fun createDeferredSession(call: ApplicationCall, providers: List<SessionProvider<*>>): StatefulSession =
BlockingDeferredSessionData(
call.coroutineContext,
providers.associateBy({ it.name }) {
CoroutineScope(call.coroutineContext).async(start = CoroutineStart.LAZY) {
it.receiveSessionData(call)
}
}
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
/*
* Copyright 2014-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/

package io.ktor.server.sessions
import kotlinx.cinterop.ExperimentalForeignApi
import kotlinx.cinterop.toKString
import platform.posix.*

@OptIn(ExperimentalForeignApi::class)
internal actual fun isDeferredSessionsEnabled(): Boolean =
getenv(SESSIONS_DEFERRED_FLAG)?.toKString()?.toBoolean() == true
Loading