diff --git a/chapi-ast-rust/src/main/kotlin/chapi/ast/rustast/RustAstBaseListener.kt b/chapi-ast-rust/src/main/kotlin/chapi/ast/rustast/RustAstBaseListener.kt index aa4286d1..379217e7 100644 --- a/chapi-ast-rust/src/main/kotlin/chapi/ast/rustast/RustAstBaseListener.kt +++ b/chapi-ast-rust/src/main/kotlin/chapi/ast/rustast/RustAstBaseListener.kt @@ -357,7 +357,7 @@ open class RustAstBaseListener(private val fileName: String) : RustParserBaseLis val functionName = ctx.identifier().text val function = CodeFunction( Name = functionName, - Package = codeContainer.PackageName, + Package = packageName, Position = buildPosition(ctx), Parameters = buildParameters(ctx.functionParameters()), ReturnType = possibleReturnType, @@ -370,7 +370,7 @@ open class RustAstBaseListener(private val fileName: String) : RustParserBaseLis val functionName = ctx.identifier().text val function = CodeFunction( Name = functionName, - Package = codeContainer.PackageName, + Package = packageName, Position = buildPosition(ctx), Parameters = buildParameters(ctx.functionParameters()), ReturnType = possibleReturnType, @@ -507,7 +507,7 @@ open class RustAstBaseListener(private val fileName: String) : RustParserBaseLis return listOf( CodeDataStruct().apply { NodeName = fileName.substringBeforeLast('.') - Module = if (lastModule == "tests") lastModule else "" + Module = lastModule Type = DataStructType.OBJECT Package = codeContainer.PackageName FilePath = codeContainer.FullName diff --git a/chapi-ast-rust/src/test/kotlin/chapi/ast/rustast/RustAnalyserTest.kt b/chapi-ast-rust/src/test/kotlin/chapi/ast/rustast/RustAnalyserTest.kt index f7faab61..ea9f72f4 100644 --- a/chapi-ast-rust/src/test/kotlin/chapi/ast/rustast/RustAnalyserTest.kt +++ b/chapi-ast-rust/src/test/kotlin/chapi/ast/rustast/RustAnalyserTest.kt @@ -1,8 +1,5 @@ package chapi.ast.rustast -import kotlinx.serialization.encodeToString -import kotlinx.serialization.json.Json -import org.junit.jupiter.api.Disabled import org.junit.jupiter.api.Test import java.io.File import kotlin.test.assertEquals @@ -72,6 +69,48 @@ internal class RustAnalyserTest { assertEquals(position.StopLine, 4) } + @Test + fun should_success_build_position_for_testing() { + val testCode = """ + use std::sync::Arc; + + pub use embedding::Embedding; + pub use embedding::Semantic; + pub use embedding::semantic::SemanticError; + + pub fn init_semantic(model: Vec, tokenizer_data: Vec) -> Result, SemanticError> { + let result = Semantic::init_semantic(model, tokenizer_data)?; + Ok(Arc::new(result)) + } + + #[cfg(test)] + mod tests { + use super::*; + + #[test] + #[cfg_attr(feature = "ci", ignore)] + fn test_init_semantic() { + let model = std::fs::read("../model/model.onnx").unwrap(); + let tokenizer_data = std::fs::read("../model/tokenizer.json").unwrap(); + + let semantic = init_semantic(model, tokenizer_data).unwrap(); + let embedding = semantic.embed("hello world").unwrap(); + assert_eq!(embedding.len(), 128); + } + } + """.trimIndent() + + + val container = rustAnalyser.analysis(testCode, "lib.rs") + val functions = container.DataStructures.map { dataStruct -> + dataStruct.Functions.filter { function -> function.Annotations.any { it.Name == "test" } } + }.flatten() + + assertEquals(functions.size, 1) + val firstFunction = functions[0] + assertEquals(firstFunction.Position.StartLine, 18) + } + @Test fun allGrammarUnderResources() { val content = this::class.java.getResource("/grammar")!!