Skip to content

Commit

Permalink
fix: add checks to prevent access to uninitialized natives
Browse files Browse the repository at this point in the history
  • Loading branch information
silenium-dev committed Oct 15, 2024
1 parent 5519e96 commit 888a785
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 56 deletions.
11 changes: 10 additions & 1 deletion build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ tasks {
useJUnitPlatform()
}

register<Copy>("generateTemplates") {
val generateTemplates = register<Copy>("generateTemplates") {
from(templateSrc)
into(templateDst)
expand(templateProps)
Expand All @@ -84,6 +84,10 @@ tasks {
outputs.dir(templateDst)
}

withType<Jar> {
dependsOn(generateTemplates)
}

compileKotlin {
dependsOn("generateTemplates")
}
Expand All @@ -102,6 +106,11 @@ sourceSets.main {
}
}

java {
withSourcesJar()
withJavadocJar()
}

allprojects {
apply<MavenPublishPlugin>()
apply<BasePlugin>()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ class VideoPlayer(hwdec: Boolean = false) : AutoCloseable {
}

override fun close() {
mpv.command("stop")
render?.close()
mpv.close()
}
Expand Down
181 changes: 129 additions & 52 deletions src/main/kotlin/dev/silenium/multimedia/core/mpv/MPV.kt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ConcurrentLinkedQueue
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicLong
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
import kotlin.coroutines.EmptyCoroutineContext
import kotlin.coroutines.resume
import kotlin.reflect.KClass
Expand Down Expand Up @@ -45,13 +48,39 @@ class MPV : NativeCleanable, MPVAsyncListener {

private val propertyUpdates = MutableSharedFlow<Property<*>>()

@OptIn(ExperimentalContracts::class)
private inline fun <R : Any> guard(other: R? = null, block: () -> Result<R>): Result<R> {
contract { callsInPlace(block, InvocationKind.AT_MOST_ONCE) }
if (!initialized.get()) {
if (other != null) {
return Result.success(other)
}
return Result.failure(IllegalStateException("MPV is not initialized"))
}
return block()
}

@OptIn(ExperimentalContracts::class)
@JvmName("guardNonNull")
private inline fun <R> guardNonNull(other: R? = null, block: () -> Result<R?>): Result<R?> {
contract { callsInPlace(block, InvocationKind.AT_MOST_ONCE) }
if (!initialized.get()) {
if (other != null) {
return Result.success(other)
}
return Result.failure(IllegalStateException("MPV is not initialized"))
}
return block()
}

override val nativePointer = NativePointer(createN().getOrThrow()) {
callback?.let(::unsetCallbackN)
listener.close()
destroyN(it)
}

fun setOption(name: String, value: String) {
if (nativePointer.closed) {
error("MPV is closed")
}
if (initialized.get()) {
logger.warn("Cannot set option after initialization, ignoring")
return
Expand All @@ -76,15 +105,17 @@ class MPV : NativeCleanable, MPVAsyncListener {
name: String,
value: T,
fn: (Long, String, T, Long) -> Result<Unit>,
): Result<Unit> = suspendCancellableCoroutine { continuation ->
val subscriptionId = propertySetCallbackId.getAndIncrement()
propertySetCallbacks[subscriptionId] = { result ->
continuation.resume(result)
}
fn(nativePointer.address, name, value, subscriptionId).onFailure {
propertySetCallbacks.remove(subscriptionId)
logger.error("Failed to set property $name", it)
continuation.resume(Result.failure(it))
): Result<Unit> = guard(Unit) {
suspendCancellableCoroutine { continuation ->
val subscriptionId = propertySetCallbackId.getAndIncrement()
propertySetCallbacks[subscriptionId] = { result ->
continuation.resume(result)
}
fn(nativePointer.address, name, value, subscriptionId).onFailure {
propertySetCallbacks.remove(subscriptionId)
logger.error("Failed to set property $name", it)
continuation.resume(Result.failure(it))
}
}
}

Expand All @@ -102,18 +133,20 @@ class MPV : NativeCleanable, MPVAsyncListener {
name: String,
type: KClass<T>,
fn: (Long, String, Long) -> Result<Unit>,
): Result<T?> = suspendCancellableCoroutine { continuation ->
val subscriptionId = propertyGetCallbackId.getAndIncrement()
propertyGetCallbacks[subscriptionId] = { result ->
continuation.resume(result.map {
logger.debug("Got property {}: {}", name, it)
it?.let(type::cast)
})
}
fn(nativePointer.address, name, subscriptionId).onFailure {
propertyGetCallbacks.remove(subscriptionId)
logger.error("Failed to get property $name", it)
continuation.resume(Result.failure(it))
): Result<T?> = guardNonNull<T>(null) {
suspendCancellableCoroutine { continuation ->
val subscriptionId = propertyGetCallbackId.getAndIncrement()
propertyGetCallbacks[subscriptionId] = { result ->
continuation.resume(result.map {
logger.debug("Got property {}: {}", name, it)
it?.let(type::cast)
})
}
fn(nativePointer.address, name, subscriptionId).onFailure {
propertyGetCallbacks.remove(subscriptionId)
logger.error("Failed to get property $name", it)
continuation.resume(Result.failure(it))
}
}
}

Expand All @@ -127,23 +160,24 @@ class MPV : NativeCleanable, MPVAsyncListener {
fun getPropertyDouble(name: String) = getPropertyDoubleN(nativePointer.address, name)
fun getPropertyFlag(name: String) = getPropertyFlagN(nativePointer.address, name)

private fun subscribe(name: String, type: KClass<*>, fn: (Long, String, Long) -> Result<Unit>): Result<Unit> {
if (propertySubscriptions.containsKey(name)) {
logger.debug("Property $name is already being observed")
return Result.success(Unit)
private fun subscribe(name: String, type: KClass<*>, fn: (Long, String, Long) -> Result<Unit>): Result<Unit> =
guard(Unit) {
if (propertySubscriptions.containsKey(name)) {
logger.debug("Property $name is already being observed")
return Result.success(Unit)
}
logger.debug("Observing property $name")
val subscriptionId = subscriptionId.getAndIncrement()
return fn(nativePointer.address, name, subscriptionId)
.map { propertySubscriptions[name] = subscriptionId to type }
}
logger.debug("Observing property $name")
val subscriptionId = subscriptionId.getAndIncrement()
return fn(nativePointer.address, name, subscriptionId)
.map { propertySubscriptions[name] = subscriptionId to type }
}

fun observePropertyString(name: String) = subscribe(name, String::class, ::observePropertyStringN)
fun observePropertyLong(name: String) = subscribe(name, Long::class, ::observePropertyLongN)
fun observePropertyDouble(name: String) = subscribe(name, Double::class, ::observePropertyDoubleN)
fun observePropertyFlag(name: String) = subscribe(name, Boolean::class, ::observePropertyFlagN)

fun unobserveProperty(name: String): Result<Unit> {
fun unobserveProperty(name: String): Result<Unit> = guard(Unit) {
val (id, _) = propertySubscriptions[name] ?: run {
logger.debug("Property $name is not being observed")
return Result.success(Unit)
Expand All @@ -160,17 +194,27 @@ class MPV : NativeCleanable, MPVAsyncListener {
override fun command(subscriptionCount: StateFlow<Int>): Flow<SharingCommand> {
return wrapped.command(subscriptionCount).onEach { command ->
when (command) {
SharingCommand.START -> subscribe(name).getOrThrow()
SharingCommand.START -> subscribe(name).getOrElse {
logger.error("Failed to subscribe to property $name", it)
return@onEach
}

SharingCommand.STOP,
SharingCommand.STOP_AND_RESET_REPLAY_CACHE,
-> unsubscribe(name).getOrThrow()
-> unsubscribe(name).getOrElse {
logger.error("Failed to unsubscribe from property $name", it)
return@onEach
}
}
}
}
}

suspend fun propertyFlowString(name: String): StateFlow<String?> {
val initialValue = getPropertyStringAsync(name).getOrThrow()
val initialValue = getPropertyStringAsync(name).getOrElse {
logger.error("Failed to get initial value for property $name", it)
null
}
val flow = propertyUpdates.filter { it.name == name }.filterIsInstance<StringProperty>().map { it.value }
return flow.stateIn(
CoroutineScope(EmptyCoroutineContext),
Expand All @@ -180,7 +224,10 @@ class MPV : NativeCleanable, MPVAsyncListener {
}

suspend fun propertyFlowLong(name: String): StateFlow<Long?> {
val initialValue = getPropertyLongAsync(name).getOrThrow()
val initialValue = getPropertyLongAsync(name).getOrElse {
logger.error("Failed to get initial value for property $name", it)
null
}
val flow = propertyUpdates.filter { it.name == name }.filterIsInstance<LongProperty>().map { it.value }
return flow.stateIn(
CoroutineScope(EmptyCoroutineContext),
Expand All @@ -190,7 +237,10 @@ class MPV : NativeCleanable, MPVAsyncListener {
}

suspend fun propertyFlowDouble(name: String): StateFlow<Double?> {
val initialValue = getPropertyDoubleAsync(name).getOrThrow()
val initialValue = getPropertyDoubleAsync(name).getOrElse {
logger.error("Failed to get initial value for property $name", it)
null
}
val flow = propertyUpdates.filter { it.name == name }.filterIsInstance<DoubleProperty>().map { it.value }
return flow.stateIn(
CoroutineScope(EmptyCoroutineContext),
Expand All @@ -200,7 +250,10 @@ class MPV : NativeCleanable, MPVAsyncListener {
}

suspend fun propertyFlowFlag(name: String): StateFlow<Boolean?> {
val initialValue = getPropertyFlagAsync(name).getOrThrow()
val initialValue = getPropertyFlagAsync(name).getOrElse {
logger.error("Failed to get initial value for property $name", it)
null
}
val flow = propertyUpdates.filter { it.name == name }.filterIsInstance<FlagProperty>().map { it.value }
return flow.stateIn(
CoroutineScope(EmptyCoroutineContext),
Expand Down Expand Up @@ -251,20 +304,27 @@ class MPV : NativeCleanable, MPVAsyncListener {
else -> error("Unsupported property type: ${value::class}")
}

fun command(command: Array<String>) = commandN(nativePointer.address, command)
fun command(command: String) = commandStringN(nativePointer.address, command)
fun command(command: Array<String>) = guard(Unit) {
commandN(nativePointer.address, command)
}

fun command(command: String) = guard(Unit) {
commandStringN(nativePointer.address, command)
}

@JvmName("commandAsyncVararg")
suspend fun commandAsync(vararg command: String) = commandAsync(command.toList().toTypedArray())
suspend fun commandAsync(command: Array<String>): Result<Unit> = suspendCancellableCoroutine { continuation ->
val subscriptionId = commandReplyCallbackId.getAndIncrement()
commandReplyCallbacks[subscriptionId] = { result ->
continuation.resume(result)
}
commandAsyncN(nativePointer.address, command, subscriptionId).onFailure {
commandReplyCallbacks.remove(subscriptionId)
logger.error("Failed to execute command $command", it)
continuation.resume(Result.failure(it))
suspend fun commandAsync(command: Array<String>): Result<Unit> = guard(Unit) {
suspendCancellableCoroutine { continuation ->
val subscriptionId = commandReplyCallbackId.getAndIncrement()
commandReplyCallbacks[subscriptionId] = { result ->
continuation.resume(result)
}
commandAsyncN(nativePointer.address, command, subscriptionId).onFailure {
commandReplyCallbacks.remove(subscriptionId)
logger.error("Failed to execute command $command", it)
continuation.resume(Result.failure(it))
}
}
}

Expand Down Expand Up @@ -304,8 +364,20 @@ class MPV : NativeCleanable, MPVAsyncListener {
createRenderN(mpv.nativePointer.address, this, advancedControl).getOrThrow()
.asNativePointer(::destroyRenderN)

fun render(fbo: FBO): Result<Unit> {
return renderN(
@OptIn(ExperimentalContracts::class)
private fun <R> guard(other: R? = null, block: () -> Result<R>): Result<R> {
contract { callsInPlace(block, InvocationKind.AT_MOST_ONCE) }
if (nativePointer.closed) {
if (other != null) {
return Result.success(other)
}
return Result.failure(IllegalStateException("Render is closed"))
}
return block()
}

fun render(fbo: FBO): Result<Unit> = guard(Unit) {
renderN(
nativePointer.address,
fbo.id,
fbo.size.width,
Expand All @@ -325,6 +397,11 @@ class MPV : NativeCleanable, MPVAsyncListener {
}
}

override fun close() {
initialized.set(false)
super.close()
}

companion object {
private val logger = LoggerFactory.getLogger(MPV::class.java)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@ import org.slf4j.LoggerFactory
import java.util.concurrent.atomic.AtomicBoolean

data class NativePointer(val address: Long, val clean: (Long) -> Unit) : AutoCloseable {
private val closed = AtomicBoolean(false)
private val _closed = AtomicBoolean(false)
val closed get() = _closed.get()
override fun close() {
if (address == 0L) {
logger.warn("Attempt to close NULL NativePointer")
return
}
if (closed.compareAndSet(false, true)) {
if (_closed.compareAndSet(false, true)) {
clean(address)
} else {
logger.warn("Attempt to close already closed NativePointer: $this")
Expand Down

0 comments on commit 888a785

Please sign in to comment.