Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove configurable weighting in DamerauLevenshteinDistance #4

Merged
merged 2 commits into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
jvm = "17"
agp = "8.5.0"
kotlin = "2.0.0"
nexus-publish = "2.0.0-rc-1"
nexus-publish = "2.0.0"
android-minSdk = "26"
android-compileSdk = "34"
compose = "1.6.11"
Expand All @@ -24,5 +24,6 @@ jmh-annprocess = { module = "org.openjdk.jmh:jmh-generator-annprocess", version.
[plugins]
android-library = { id = "com.android.library", version.ref = "agp" }
kotlin-multiplatform = { id = "org.jetbrains.kotlin.multiplatform", version.ref = "kotlin" }
kotlin-powerassert = { id = "org.jetbrains.kotlin.plugin.power-assert", version.ref = "kotlin" }
compose = { id = "org.jetbrains.compose", version.ref = "compose" }
compose-compiler = { id = "org.jetbrains.kotlin.plugin.compose", version.ref = "kotlin" }
13 changes: 13 additions & 0 deletions library/build.gradle.kts
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import org.jetbrains.kotlin.gradle.ExperimentalKotlinGradlePluginApi
import org.jetbrains.kotlin.gradle.targets.js.dsl.ExperimentalWasmDsl

plugins {
alias(libs.plugins.kotlin.multiplatform)
alias(libs.plugins.kotlin.powerassert)
alias(libs.plugins.android.library)
id("module.publication")
}
Expand Down Expand Up @@ -86,3 +88,14 @@ android {
targetCompatibility = JavaVersion.toVersion(libs.versions.jvm.get().toInt())
}
}

// Configuration for the Kotlin Power-assert compiler plugin applied in the
// plugins block above. The extension is still experimental, hence the opt-in.
@OptIn(ExperimentalKotlinGradlePluginApi::class)
powerAssert {
// Fully-qualified names of the assertion functions the plugin should
// instrument so that failures report the values of intermediate expressions.
functions = listOf(
"kotlin.assert",
"kotlin.test.assertTrue",
"kotlin.test.assertEquals",
"kotlin.test.assertNull"
)
// Source sets the transformation is applied to.
// NOTE(review): "commonMain" is production code, so `kotlin.assert` calls
// in the library itself are instrumented too - confirm this is intended.
includedSourceSets = listOf("commonMain", "jvmTest")
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,26 +14,6 @@ data class SpellCheckSettings(
*/
val topK: Int = 10, // limits result to n entries

/**
* Damerau function variables Deletion weight: 1.20 ~ 1.40
*/
val deletionWeight: Double = 1.0,

/**
* Damerau function variables Insertion weight: 1.01
*/
val insertionWeight: Double = 1.0,

/**
* Damerau function variables Replace weight: 0.9f ~ 1.20
*/
val replaceWeight: Double = 1.0,

/**
* Damerau function variables Transposition weight: 0.7f ~ 1.05
*/
val transpositionWeight: Double = 1.0,

/**
* true if the spellchecker should lowercase terms
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.darkrockstudios.symspellkt.common
package com.darkrockstudios.symspellkt.common.stringdistance

import com.darkrockstudios.symspellkt.api.StringDistance

Expand All @@ -7,16 +7,8 @@ import com.darkrockstudios.symspellkt.api.StringDistance
* sequences. Informally, the Damerau–Levenshtein distance between two words is the minimum number
* of operations (consisting of insertions, deletions or substitutions of a single character, or
* transposition of two adjacent characters) required to change one word into the other.
*
* In this variant of DamerauLevenshteinDistance, it has different weights associated to each
* action.
*/
class WeightedDamerauLevenshteinDistance(
private val deletionWeight: Double = 0.8,
private val insertionWeight: Double = 1.01,
private val replaceWeight: Double = 0.9,
private val transpositionWeight: Double = 0.7,
) : StringDistance {
class DamerauLevenshteinDistance : StringDistance {

override fun getDistance(w1: String, w2: String): Double {
if (w1 == w2) {
Expand All @@ -39,10 +31,10 @@ class WeightedDamerauLevenshteinDistance(

// Step 2
for (i in w2.length downTo 0) {
d[i][0] = i * insertionWeight // Add insertion weight
d[i][0] = i * INSERTION // Add insertion weight
}
for (j in w1.length downTo 0) {
d[0][j] = j * deletionWeight
d[0][j] = j * DELETION
}

for (i in 1..w2.length) {
Expand All @@ -53,12 +45,12 @@ class WeightedDamerauLevenshteinDistance(
val cost = getReplaceCost(target_i, source_j)

var min = min(
d[i - 1][j] + insertionWeight, //Insertion
d[i][j - 1] + deletionWeight, //Deletion
d[i - 1][j] + INSERTION, //Insertion
d[i][j - 1] + DELETION, //Deletion
d[i - 1][j - 1] + cost
) //Replacement
if (isTransposition(i, j, w1, w2)) {
min = kotlin.math.min(min, d[i - 2][j - 2] + transpositionWeight) // transpose
min = kotlin.math.min(min, d[i - 2][j - 2] + TRANSPORTATION) // transpose
}
d[i][j] = min
}
Expand Down Expand Up @@ -87,9 +79,16 @@ class WeightedDamerauLevenshteinDistance(

private fun getReplaceCost(aI: Char, bJ: Char): Double {
return if (aI != bJ) {
replaceWeight
REPLACEMENT
} else {
0.0
}
}

companion object {
private const val DELETION: Double = 1.0
private const val INSERTION: Double = 1.0
private const val REPLACEMENT: Double = 1.0
private const val TRANSPORTATION: Double = 1.0
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package com.darkrockstudios.symspellkt.common.stringdistance

import com.darkrockstudios.symspellkt.api.StringDistance
import kotlin.math.min

/**
 * A pure Levenshtein Distance implementation.
 *
 * Every insertion, deletion and substitution of a single character costs 1.0;
 * transpositions are not treated specially. Only two rows of the dynamic
 * programming table are kept at any time, so memory use grows with the length
 * of `lhs` rather than with the product of both lengths.
 */
class LevenshteinDistance : StringDistance {

    /**
     * Computes the Levenshtein distance between [lhs] and [rhs].
     *
     * @return `0.0` for equal strings, the length of the other string when one
     * of them is empty, otherwise the minimum number of single-character edits.
     */
    override fun getDistance(lhs: String, rhs: String): Double {
        if (lhs == rhs) {
            return 0.0
        }
        if (lhs.isEmpty()) {
            return rhs.length.toDouble()
        }
        if (rhs.isEmpty()) {
            return lhs.length.toDouble()
        }

        val columns = lhs.length + 1

        // Primitive DoubleArray rows: Array<Double> would box every cell that
        // the inner loop reads and writes.
        // previousRow starts as the distances from the empty prefix of rhs to
        // each prefix of lhs, i.e. 0, 1, 2, ...
        var previousRow = DoubleArray(columns) { it.toDouble() }
        var currentRow = DoubleArray(columns)

        for (i in 1..rhs.length) {
            // Distance from the first i characters of rhs to the empty prefix of lhs.
            currentRow[0] = i.toDouble()

            for (j in 1..lhs.length) {
                val substitution = if (lhs[j - 1] == rhs[i - 1]) 0.0 else 1.0

                currentRow[j] = min(
                    min(previousRow[j] + 1.0, currentRow[j - 1] + 1.0),
                    previousRow[j - 1] + substitution,
                )
            }

            // Recycle the two buffers instead of allocating a new row per iteration.
            val recycled = previousRow
            previousRow = currentRow
            currentRow = recycled
        }

        // After the final swap the last completed row lives in previousRow.
        return previousRow[columns - 1]
    }

    /**
     * Computes the Levenshtein distance between [w1] and [w2], bounded by
     * [maxEditDistance].
     *
     * @return the distance, or `-1.0` when it is greater than [maxEditDistance].
     */
    override fun getDistance(w1: String, w2: String, maxEditDistance: Double): Double {
        // The distance can never be smaller than the difference in length, so
        // the full table computation is skipped when that alone exceeds the bound.
        if (kotlin.math.abs(w1.length - w2.length) > maxEditDistance) {
            return -1.0
        }

        val distance = getDistance(w1, w2)
        return if (distance > maxEditDistance) -1.0 else distance
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,35 +4,25 @@ import com.darkrockstudios.symspellkt.api.SpellChecker
import com.darkrockstudios.symspellkt.common.Murmur3HashFunction
import com.darkrockstudios.symspellkt.common.SpellCheckSettings
import com.darkrockstudios.symspellkt.common.Verbosity
import com.darkrockstudios.symspellkt.common.WeightedDamerauLevenshteinDistance
import com.darkrockstudios.symspellkt.common.stringdistance.DamerauLevenshteinDistance

fun createSymSpellChecker(
settings: SpellCheckSettings? = null,
): SpellChecker {
val spellCheckSettings = settings ?: SpellCheckSettings(
countThreshold = 1,
deletionWeight = 1.0,
insertionWeight = 1.0,
replaceWeight = 1.0,
maxEditDistance = 2.0,
transpositionWeight = 1.0,
topK = 5,
prefixLength = 10,
verbosity = Verbosity.ALL,
)

val weightedDamerauLevenshteinDistance =
WeightedDamerauLevenshteinDistance(
spellCheckSettings.deletionWeight,
spellCheckSettings.insertionWeight,
spellCheckSettings.replaceWeight,
spellCheckSettings.transpositionWeight,
)
val damerauLevenshteinDistance = DamerauLevenshteinDistance()
val dataHolder = InMemoryDataHolder(spellCheckSettings, Murmur3HashFunction())

val symSpellCheck = SymSpellCheck(
dataHolder,
weightedDamerauLevenshteinDistance,
damerauLevenshteinDistance,
spellCheckSettings
)

Expand Down
34 changes: 20 additions & 14 deletions library/src/jvmTest/kotlin/symspellkt/AccuracyTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@ package symspellkt

import com.darkrockstudios.symspellkt.api.DataHolder
import com.darkrockstudios.symspellkt.api.SpellChecker
import com.darkrockstudios.symspellkt.api.StringDistance
import com.darkrockstudios.symspellkt.common.DictionaryItem
import com.darkrockstudios.symspellkt.common.Murmur3HashFunction
import com.darkrockstudios.symspellkt.common.SpellCheckSettings
import com.darkrockstudios.symspellkt.common.SuggestionItem
import com.darkrockstudios.symspellkt.common.WeightedDamerauLevenshteinDistance
import com.darkrockstudios.symspellkt.common.stringdistance.DamerauLevenshteinDistance
import com.darkrockstudios.symspellkt.common.stringdistance.LevenshteinDistance
import com.darkrockstudios.symspellkt.exception.SpellCheckException
import com.darkrockstudios.symspellkt.impl.InMemoryDataHolder
import com.darkrockstudios.symspellkt.impl.SymSpellCheck
Expand Down Expand Up @@ -157,12 +157,11 @@ class AccuracyTest {
fun testAccuracy() {
val accuracyTest = AccuracyTest()

println("========= SymSpell =============================")
println("========= Pure DamerauLevenshteinDistance =============================")
//Basic
var spellCheckSettings = SpellCheckSettings(
countThreshold = 0,
prefixLength = 40,
maxEditDistance = 2.0
)

var dataHolder: DataHolder = InMemoryDataHolder(
Expand All @@ -172,21 +171,28 @@ class AccuracyTest {

val spellChecker: SpellChecker = SymSpellCheck(
dataHolder,
accuracyTest.getStringDistance(spellCheckSettings),
DamerauLevenshteinDistance(),
spellCheckSettings
)
accuracyTest.run(spellChecker)
}

private fun getStringDistance(
spellCheckSettings: SpellCheckSettings,
): StringDistance {
return WeightedDamerauLevenshteinDistance(
spellCheckSettings.deletionWeight,
spellCheckSettings.insertionWeight,
spellCheckSettings.replaceWeight,
spellCheckSettings.transpositionWeight,
println("========= Pure Levenshtein =============================")
spellCheckSettings = SpellCheckSettings(
countThreshold = 0,
prefixLength = 40,
)
dataHolder = InMemoryDataHolder(
spellCheckSettings,
Murmur3HashFunction()
)
val pureLevenshteinSpellChecker: SpellChecker = SymSpellCheck(
dataHolder,
LevenshteinDistance(),
spellCheckSettings
)
accuracyTest.run(pureLevenshteinSpellChecker)
println("==================================================")

}

companion object {
Expand Down
17 changes: 4 additions & 13 deletions library/src/jvmTest/kotlin/symspellkt/GermanLangSpellChecker.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package symspellkt

import com.darkrockstudios.symspellkt.api.DataHolder
import com.darkrockstudios.symspellkt.common.*
import com.darkrockstudios.symspellkt.common.stringdistance.DamerauLevenshteinDistance
import com.darkrockstudios.symspellkt.exception.SpellCheckException
import com.darkrockstudios.symspellkt.impl.InMemoryDataHolder
import com.darkrockstudios.symspellkt.impl.SymSpellCheck
Expand All @@ -19,7 +20,7 @@ class GermanLangSpellChecker {
lateinit var dataHolder1: DataHolder
lateinit var dataHolder2: DataHolder
lateinit var symSpellCheck: SymSpellCheck
lateinit var weightedDamerauLevenshteinDistance: WeightedDamerauLevenshteinDistance
lateinit var damerauLevenshteinDistance: DamerauLevenshteinDistance

@Before
@Throws(IOException::class, SpellCheckException::class)
Expand All @@ -28,30 +29,20 @@ class GermanLangSpellChecker {

val spellCheckSettings = SpellCheckSettings(
countThreshold = 1,
deletionWeight = 1.0,
insertionWeight = 1.0,
replaceWeight = 1.0,
maxEditDistance = 2.0,
transpositionWeight = 1.0,
topK = 5,
prefixLength = 10,
verbosity = Verbosity.ALL,
)

weightedDamerauLevenshteinDistance =
WeightedDamerauLevenshteinDistance(
spellCheckSettings.deletionWeight,
spellCheckSettings.insertionWeight,
spellCheckSettings.replaceWeight,
spellCheckSettings.transpositionWeight,
)
damerauLevenshteinDistance = DamerauLevenshteinDistance()

dataHolder1 = InMemoryDataHolder(spellCheckSettings, Murmur3HashFunction())
dataHolder2 = InMemoryDataHolder(spellCheckSettings, Murmur3HashFunction())

symSpellCheck = SymSpellCheck(
dataHolder1,
weightedDamerauLevenshteinDistance,
damerauLevenshteinDistance,
spellCheckSettings
)

Expand Down
36 changes: 36 additions & 0 deletions library/src/jvmTest/kotlin/symspellkt/LevenshteinDistanceTest.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package symspellkt

import com.darkrockstudios.symspellkt.common.stringdistance.LevenshteinDistance
import org.junit.Before
import kotlin.test.Test
import kotlin.test.assertEquals


/**
 * Unit tests for [LevenshteinDistance], covering both the unbounded distance
 * and the overload that is limited by a maximum edit distance.
 */
class LevenshteinDistanceTest {
    private lateinit var levenshteinDistance: LevenshteinDistance

    @Before
    fun setup() {
        levenshteinDistance = LevenshteinDistance()
    }

    @Test
    fun `Distance Test`() {
        testDistance("", "", 0)
        testDistance("1", "1", 0)
        testDistance("1", "2", 1)
        testDistance("12", "12", 0)
        testDistance("123", "12", 1)
        testDistance("1234", "1", 3)
        testDistance("1234", "1233", 1)
        testDistance("", "12345", 5)
        // Same pair with the empty string on the right-hand side.
        testDistance("12345", "", 5)
        testDistance("kitten", "mittens", 2)
        testDistance("canada", "canad", 1)
        testDistance("canad", "canada", 1)
    }

    @Test
    fun `Bounded Distance Test`() {
        // Within the bound: the real distance is returned.
        testBoundedDistance("", "", 0.0, 0.0)
        testBoundedDistance("canada", "canad", 1.0, 1.0)
        testBoundedDistance("kitten", "mittens", 2.0, 2.0)
        testBoundedDistance("", "12345", 5.0, 5.0)

        // Beyond the bound: the sentinel -1.0 is returned.
        testBoundedDistance("kitten", "mittens", 1.0, -1.0)
        testBoundedDistance("1234", "1", 2.0, -1.0)
        testBoundedDistance("", "12345", 4.0, -1.0)
    }

    /** Asserts that the unbounded distance between [a] and [b] is [expectedDistance]. */
    private fun testDistance(a: String, b: String, expectedDistance: Int) {
        val d = levenshteinDistance.getDistance(a, b)
        assertEquals(expectedDistance.toDouble(), d, "Distance did not match for `$a` and `$b`")
    }

    /** Asserts the result of the overload limited by [maxEditDistance]. */
    private fun testBoundedDistance(a: String, b: String, maxEditDistance: Double, expected: Double) {
        val d = levenshteinDistance.getDistance(a, b, maxEditDistance)
        assertEquals(expected, d, "Bounded distance did not match for `$a` and `$b` with max $maxEditDistance")
    }
}
Loading
Loading