From 347c74632913d038b473ab74b5e71163df29879d Mon Sep 17 00:00:00 2001 From: camdenorrb Date: Mon, 17 Jan 2022 06:51:35 -0600 Subject: [PATCH] Fix the majority of my tests --- .../lang/crescent/vm/CrescentVM.kt | 75 ++++++++++--------- .../lang/crescent/CrescentLexerTests.kt | 28 +++++-- .../lang/crescent/CrescentParserTests.kt | 61 +++++++++++---- .../lang/crescent/CrescentVMTests.kt | 24 ++++-- .../lang/crescent/data/TestCode.kt | 7 +- 5 files changed, 130 insertions(+), 65 deletions(-) diff --git a/src/main/kotlin/dev/twelveoclock/lang/crescent/vm/CrescentVM.kt b/src/main/kotlin/dev/twelveoclock/lang/crescent/vm/CrescentVM.kt index 0b0d4e6..1aced51 100644 --- a/src/main/kotlin/dev/twelveoclock/lang/crescent/vm/CrescentVM.kt +++ b/src/main/kotlin/dev/twelveoclock/lang/crescent/vm/CrescentVM.kt @@ -43,7 +43,13 @@ class CrescentVM(val files: List, val mainFile: Node.File) { // TODO: Account for default params checkEquals(function.params.size, args.size) - val functionContext = context.copy() + // Need to manually copy since .copy uses the same map instance, .toMutableMap() clones properly + val functionContext = BlockContext( + context.file, + context.holder, + context.parameters.toMutableMap(), + context.variables.toMutableMap() + ) function.params.forEachIndexed { index, parameter -> @@ -53,21 +59,24 @@ class CrescentVM(val files: List, val mainFile: Node.File) { "Parameter type doesn't match argument: $parameterType != ${findType(args[index])}" } - functionContext.parameters[parameter.name] = Variable(parameter.name, Instance(findType(arg), arg), false) + functionContext.parameters[parameter.name] = Variable(parameter.name, Instance(findType(arg), arg), true) } - return runBlock(function.innerCode, functionContext) + return when (val result = runBlock(function.innerCode, functionContext)) { + is Node.Return -> runNode(result.expression, functionContext) + else -> result + } } // TODO: Have a return value fun runBlock(block: Node.Statement.Block, context: BlockContext): Node { block.nodes.forEachIndexed { index, node -> - // If is last node in the block - if (index + 1 == block.nodes.size || node is Node.Return) { - return runNode(node, context) - } else { - runNode(node, context) + + val result = runNode(node, context) + + if (index + 1 == block.nodes.size || result is Node.Return) { + return result } } @@ -108,13 +117,13 @@ class CrescentVM(val files: List, val mainFile: Node.File) { } is Node.Return -> { - return runNode(node.expression, context) + return node } // TODO: Account for operator overloading is Node.GetCall -> { - val arrayNode = (context.parameters[node.identifier] + val arrayNode = (context.parameters[node.identifier]?.instance?.value ?: context.variables.getValue(node.identifier).instance.value) as Node.Array return arrayNode.values[(runNode(node.arguments[0], context) as Primitive.Number).toI32().data] @@ -142,13 +151,18 @@ class CrescentVM(val files: List, val mainFile: Node.File) { } is Node.Statement.If -> { - return if ((runNode(node.predicate, context) as Primitive.Boolean).data) { - runBlock(node.block, context) - } else { - node.elseBlock?.let { - runBlock(it, context) + + val result = + if ((runNode(node.predicate, context) as Primitive.Boolean).data) { + runBlock(node.block, context) } - } ?: Type.unit + else { + node.elseBlock?.let { + runBlock(it, context) + } + } ?: Type.unit + + return result } is Node.Statement.While -> { @@ -264,7 +278,8 @@ class CrescentVM(val files: List, val mainFile: Node.File) { is Node.Variable.Basic, is Node.Variable.Local -> { - runVariable(node as Node.Variable, context) + val variable = runVariable(node as Node.Variable, context) + context.variables[variable.name] = variable } else -> error("Unexpected node: $node") @@ -412,8 +427,7 @@ class CrescentVM(val files: List, val mainFile: Node.File) { is Node.GetCall -> { checkEquals(1, pop2.arguments.size) val index = (pop2.arguments.first() as Primitive.Number).toI32().data - (context.variables.getValue(pop2.identifier).instance.value as Node.Array).values[index] = - value + (context.variables.getValue(pop2.identifier).instance.value as Node.Array).values[index] = value } is Node.Identifier -> { @@ -622,14 +636,7 @@ class CrescentVM(val files: List, val mainFile: Node.File) { "sqrt" -> { checkEquals(1, node.arguments.size) - return Primitive.Number.F64( - sqrt( - (runNode( - node.arguments[0], - context - ) as Primitive.Number).toF64().data - ) - ) + return Primitive.Number.F64(sqrt((runNode(node.arguments[0], context) as Primitive.Number).toF64().data)) } "sin" -> { @@ -639,14 +646,7 @@ class CrescentVM(val files: List, val mainFile: Node.File) { "round" -> { checkEquals(1, node.arguments.size) - return Primitive.Number.F64( - round( - (runNode( - node.arguments[0], - context - ) as Primitive.Number).toF64().data - ) - ) + return Primitive.Number.F64(round((runNode(node.arguments[0], context) as Primitive.Number).toF64().data)) } "print" -> { @@ -726,7 +726,7 @@ class CrescentVM(val files: List, val mainFile: Node.File) { check(parameter is Node.Parameter.Basic) { "Crescent doesn't support parameters with default values yet." } - check(parameter.type == Type.any || parameter.type == findType(argumentValues[index])) { + checkIsSameType(parameter.type, argumentValues[index]) { "Parameter ${parameter.name} had an argument of type ${findType(argumentValues[index])}, expected ${parameter.type}" } } @@ -791,7 +791,8 @@ class CrescentVM(val files: List, val mainFile: Node.File) { Type.Basic(value.identifier) } else { - error("Unexpected value: ${value::class}") + // TODO: Resolve the function return type + Type.any } } diff --git a/src/test/kotlin/dev/twelveoclock/lang/crescent/CrescentLexerTests.kt b/src/test/kotlin/dev/twelveoclock/lang/crescent/CrescentLexerTests.kt index c2d9cee..2e707b2 100644 --- a/src/test/kotlin/dev/twelveoclock/lang/crescent/CrescentLexerTests.kt +++ b/src/test/kotlin/dev/twelveoclock/lang/crescent/CrescentLexerTests.kt @@ -48,7 +48,8 @@ internal class CrescentLexerTests { assertContentEquals( listOf( - FUN, Key("main"), Parenthesis.OPEN, Key("args"), TYPE_PREFIX, SquareBracket.OPEN, Key("String"), SquareBracket.CLOSE, Parenthesis.CLOSE, Bracket.OPEN, + + FUN, Key("test1"), Parenthesis.OPEN, Key("args"), TYPE_PREFIX, SquareBracket.OPEN, Key("String"), SquareBracket.CLOSE, Parenthesis.CLOSE, Bracket.OPEN, IF, Parenthesis.OPEN, Key("args"), SquareBracket.OPEN, Data.Number(0.toByte()), SquareBracket.CLOSE, EQUALS_COMPARE, Data.String("true"), Parenthesis.CLOSE, Bracket.OPEN, Key("println"), Parenthesis.OPEN, Data.String("Meow"), Parenthesis.CLOSE, Bracket.CLOSE, @@ -56,6 +57,21 @@ internal class CrescentLexerTests { Key("println"), Parenthesis.OPEN, Data.String("Hiss"), Parenthesis.CLOSE, Bracket.CLOSE, Bracket.CLOSE, + + FUN, Key("test2"), Parenthesis.OPEN, Key("args"), TYPE_PREFIX, SquareBracket.OPEN, Key("String"), SquareBracket.CLOSE, Parenthesis.CLOSE, RETURN, Key("String"), Bracket.OPEN, + IF, Parenthesis.OPEN, Key("args"), SquareBracket.OPEN, Data.Number(0.toByte()), SquareBracket.CLOSE, EQUALS_COMPARE, Data.String("true"), Parenthesis.CLOSE, Bracket.OPEN, + RETURN, Data.String("Meow"), + Bracket.CLOSE, + ELSE, Bracket.OPEN, + RETURN, Data.String("Hiss"), + Bracket.CLOSE, + Key("println"), Parenthesis.OPEN, Data.String("This shouldn't be printed"), Parenthesis.CLOSE, + Bracket.CLOSE, + + FUN, Key("main"), Parenthesis.OPEN, Key("args"), TYPE_PREFIX, SquareBracket.OPEN, Key("String"), SquareBracket.CLOSE, Parenthesis.CLOSE, Bracket.OPEN, + Key("test1"), Parenthesis.OPEN, Key("args"), Parenthesis.CLOSE, + Key("println"), Parenthesis.OPEN, Key("test2"), Parenthesis.OPEN, Key("args"), Parenthesis.CLOSE, Parenthesis.CLOSE, + Bracket.CLOSE ), tokens ) @@ -205,7 +221,7 @@ internal class CrescentLexerTests { OBJECT, Key("Constants"), Bracket.OPEN, CONST, Key("thing2"), ASSIGN, Data.String("Meow"), - FUN, Key("printThing"), Parenthesis.OPEN, Parenthesis.CLOSE, Bracket.OPEN, + FUN, Key("printThings"), Parenthesis.OPEN, Parenthesis.CLOSE, Bracket.OPEN, Key("println"), Parenthesis.OPEN, Key("thing1"), Parenthesis.CLOSE, Key("println"), Parenthesis.OPEN, Key("thing2"), Parenthesis.CLOSE, Bracket.CLOSE, @@ -213,6 +229,8 @@ internal class CrescentLexerTests { FUN, Key("main"), Bracket.OPEN, Key("Constants"), DOT, Key("printThings"), Parenthesis.OPEN, Parenthesis.CLOSE, + Key("println"), Parenthesis.OPEN, Key("thing1"), Parenthesis.CLOSE, + Key("println"), Parenthesis.OPEN, Key("Constants"), DOT, Key("thing2"), Parenthesis.CLOSE, Bracket.CLOSE ), tokens @@ -228,7 +246,7 @@ internal class CrescentLexerTests { listOf( STRUCT, Key("Example"), Parenthesis.OPEN, - VAL, Key("aNumber"), TYPE_PREFIX, Key("Int"), Data.Comment("New lines makes commas redundant"), + VAL, Key("aNumber"), TYPE_PREFIX, Key("I32"), Data.Comment("New lines makes commas redundant"), VAL, Key("aValue1"), Key("aValue2"), ASSIGN, Data.String(""), Data.Comment("Multi declaration of same type, can all be set to one or multiple default values"), Parenthesis.CLOSE, @@ -243,10 +261,10 @@ internal class CrescentLexerTests { Data.Comment("Can't use self in static syntax"), IMPL, Modifier.STATIC, Key("Example"), Bracket.OPEN, - FUN, Key("add"), Parenthesis.OPEN, Key("value1"), Key("value2"), TYPE_PREFIX, Key("Int"), Parenthesis.CLOSE, RETURN, Key("Int"), Bracket.OPEN, + FUN, Key("add"), Parenthesis.OPEN, Key("value1"), Key("value2"), TYPE_PREFIX, Key("I32"), Parenthesis.CLOSE, RETURN, Key("I32"), Bracket.OPEN, RETURN, Key("value1"), ADD, Key("value2"), Bracket.CLOSE, - FUN, Key("sub"), Parenthesis.OPEN, Key("value1"), Key("value2"), TYPE_PREFIX, Key("Int"), Parenthesis.CLOSE, RETURN, Key("Int"), Bracket.OPEN, + FUN, Key("sub"), Parenthesis.OPEN, Key("value1"), Key("value2"), TYPE_PREFIX, Key("I32"), Parenthesis.CLOSE, RETURN, Key("I32"), Bracket.OPEN, RETURN, Key("value1"), SUB, Key("value2"), Bracket.CLOSE, Bracket.CLOSE, diff --git a/src/test/kotlin/dev/twelveoclock/lang/crescent/CrescentParserTests.kt b/src/test/kotlin/dev/twelveoclock/lang/crescent/CrescentParserTests.kt index 03f38b3..a286683 100644 --- a/src/test/kotlin/dev/twelveoclock/lang/crescent/CrescentParserTests.kt +++ b/src/test/kotlin/dev/twelveoclock/lang/crescent/CrescentParserTests.kt @@ -55,31 +55,60 @@ internal class CrescentParserTests { fun ifStatement() { val tokens = CrescentLexer.invoke(TestCode.ifStatement) + val parsed = CrescentParser.invoke(Path.of("example.crescent"), tokens) val mainFunction = assertNotNull( - CrescentParser.invoke(Path.of("example.crescent"), tokens).mainFunction, + parsed.mainFunction, "No main function found" ) - assertContentEquals( - listOf(Parameter.Basic("args", Type.Array(Type.Basic("String")))), - mainFunction.params - ) assertContentEquals( listOf( Statement.If( - Expression(listOf( + predicate = Expression(listOf( GetCall("args", listOf(I8(0))), String("true"), EQUALS_COMPARE )), - Statement.Block(listOf( + block = Statement.Block(listOf( IdentifierCall("println", listOf(String("Meow"))) )), - Statement.Block(listOf( + elseBlock = Statement.Block(listOf( IdentifierCall("println", listOf(String("Hiss"))) )), ), ), + parsed.functions["test1"]!!.innerCode.nodes, + ) + + assertContentEquals( + listOf( + Statement.If( + predicate = Expression(listOf( + GetCall("args", listOf(I8(0))), String("true"), EQUALS_COMPARE + )), + block = Statement.Block(listOf( + Return(String("Meow")) + )), + elseBlock = Statement.Block(listOf( + Return(String("Hiss")) + )), + ), + IdentifierCall("println", listOf(String("This shouldn't be printed"))) + ), + parsed.functions["test2"]!!.innerCode.nodes, + ) + + + assertContentEquals( + listOf(Parameter.Basic("args", Type.Array(Type.Basic("String")))), + mainFunction.params + ) + + assertContentEquals( + listOf( + IdentifierCall("test1", listOf(Identifier("args"))), + IdentifierCall("println", listOf(IdentifierCall("test2", listOf(Identifier("args"))))) + ), mainFunction.innerCode.nodes, ) } @@ -343,11 +372,13 @@ internal class CrescentParserTests { IdentifierCall("println", listOf(Identifier("thing1"))), IdentifierCall("println", listOf(Identifier("thing2"))) ), - constantsObject.functions["printThing"]!!.innerCode.nodes + constantsObject.functions["printThings"]!!.innerCode.nodes ) assertContentEquals( listOf( - DotChain(listOf(Identifier("Constants"), IdentifierCall("printThings"))) + DotChain(listOf(Identifier("Constants"), IdentifierCall("printThings"))), + IdentifierCall("println", listOf(Identifier("thing1"))), + IdentifierCall("println", listOf(DotChain(listOf(Identifier("Constants"), Identifier("thing2"))))), ), mainFunction.innerCode.nodes ) @@ -363,7 +394,7 @@ internal class CrescentParserTests { assertContentEquals( listOf( Struct("Example", listOf( - Variable.Basic("aNumber", Type.Basic("Int"), Expression(emptyList()), true, Visibility.PUBLIC), + Variable.Basic("aNumber", Type.Basic("I32"), Expression(emptyList()), true, Visibility.PUBLIC), Variable.Basic("aValue1", Type.Implicit, String(""), true, Visibility.PUBLIC), Variable.Basic("aValue2", Type.Implicit, String(""), true, Visibility.PUBLIC), )) @@ -426,8 +457,8 @@ internal class CrescentParserTests { name = "add", modifiers = emptyList(), visibility = Visibility.PUBLIC, - params = listOf(Parameter.Basic("value1", Type.Basic("Int")), Parameter.Basic("value2", Type.Basic("Int"))), - returnType = Type.Basic("Int"), + params = listOf(Parameter.Basic("value1", Type.Basic("I32")), Parameter.Basic("value2", Type.Basic("I32"))), + returnType = Type.Basic("I32"), innerCode = Statement.Block(listOf( Return(Expression(listOf(Identifier("value1"), Identifier("value2"), ADD))) )) @@ -436,8 +467,8 @@ internal class CrescentParserTests { name = "sub", modifiers = emptyList(), visibility = Visibility.PUBLIC, - params = listOf(Parameter.Basic("value1", Type.Basic("Int")), Parameter.Basic("value2", Type.Basic("Int"))), - returnType = Type.Basic("Int"), + params = listOf(Parameter.Basic("value1", Type.Basic("I32")), Parameter.Basic("value2", Type.Basic("I32"))), + returnType = Type.Basic("I32"), innerCode = Statement.Block(listOf( Return(Expression(listOf(Identifier("value1"), Identifier("value2"), SUB))) )) diff --git a/src/test/kotlin/dev/twelveoclock/lang/crescent/CrescentVMTests.kt b/src/test/kotlin/dev/twelveoclock/lang/crescent/CrescentVMTests.kt index 4c18f69..674f946 100644 --- a/src/test/kotlin/dev/twelveoclock/lang/crescent/CrescentVMTests.kt +++ b/src/test/kotlin/dev/twelveoclock/lang/crescent/CrescentVMTests.kt @@ -67,9 +67,11 @@ internal class CrescentVMTests { val file = CrescentParser.invoke(Path("example.crescent"), CrescentLexer.invoke(TestCode.argsHelloWorld)) + println(file.mainFunction!!.innerCode.nodes) + assertEquals( "Hello World\n", - collectSystemOut { + collectSystemOut(true) { CrescentVM(listOf(file), file).invoke(listOf("Hello World")) } ) @@ -90,6 +92,8 @@ internal class CrescentVMTests { -5 Meow Meow + Cats + Basic(Unit) """.trimIndent(), collectSystemOut(true) { @@ -104,14 +108,22 @@ internal class CrescentVMTests { val file = CrescentParser.invoke(Path("example.crescent"), CrescentLexer.invoke(TestCode.ifStatement)) assertEquals( - "Meow\n", - collectSystemOut { + """ + Meow + Meow + + """.trimIndent(), + collectSystemOut(true) { CrescentVM(listOf(file), file).invoke(listOf("true")) } ) assertEquals( - "Hiss\n", + """ + Hiss + Hiss + + """.trimIndent(), collectSystemOut { CrescentVM(listOf(file), file).invoke(listOf("false")) } @@ -129,7 +141,7 @@ internal class CrescentVMTests { Meow """.trimIndent(), - collectSystemOut { + collectSystemOut(true) { fakeUserInput("true") { CrescentVM(listOf(file), file).invoke() } @@ -215,6 +227,8 @@ internal class CrescentVMTests { """ Mew Meow + Mew + Meow """.trimIndent(), collectSystemOut(true) { diff --git a/src/test/kotlin/dev/twelveoclock/lang/crescent/data/TestCode.kt b/src/test/kotlin/dev/twelveoclock/lang/crescent/data/TestCode.kt index 13c604e..c3e0bb6 100644 --- a/src/test/kotlin/dev/twelveoclock/lang/crescent/data/TestCode.kt +++ b/src/test/kotlin/dev/twelveoclock/lang/crescent/data/TestCode.kt @@ -134,8 +134,8 @@ internal object TestCode { const val ifStatement = """ fun test1(args: [String]) { - if (args[0] == "true") { - println("Meow") + if (args[0] == "true") { + println("Meow") } else { println("Hiss") @@ -483,13 +483,14 @@ internal object TestCode { if (n >= 0){ triangle(n-1, k+1); - + var x: I32 = 0; var y: I32 = 0; while (x < k){ print(" ") x = x + 1 } + while (y < n){ print("* ") y = y + 1