Skip to content

Commit

Permalink
Fix the majority of my tests
Browse files Browse the repository at this point in the history
  • Loading branch information
camdenorrb committed Jan 17, 2022
1 parent 7b630f6 commit 347c746
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 65 deletions.
75 changes: 38 additions & 37 deletions src/main/kotlin/dev/twelveoclock/lang/crescent/vm/CrescentVM.kt
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,13 @@ class CrescentVM(val files: List<Node.File>, val mainFile: Node.File) {
// TODO: Account for default params
checkEquals(function.params.size, args.size)

val functionContext = context.copy()
// Need to manually copy since .copy uses the same map instance, .toMutableMap() clones properly
val functionContext = BlockContext(
context.file,
context.holder,
context.parameters.toMutableMap(),
context.variables.toMutableMap()
)

function.params.forEachIndexed { index, parameter ->

Expand All @@ -53,21 +59,24 @@ class CrescentVM(val files: List<Node.File>, val mainFile: Node.File) {
"Parameter type doesn't match argument: $parameterType != ${findType(args[index])}"
}

functionContext.parameters[parameter.name] = Variable(parameter.name, Instance(findType(arg), arg), false)
functionContext.parameters[parameter.name] = Variable(parameter.name, Instance(findType(arg), arg), true)
}

return runBlock(function.innerCode, functionContext)
return when (val result = runBlock(function.innerCode, functionContext)) {
is Node.Return -> runNode(result.expression, functionContext)
else -> result
}
}

// TODO: Have a return value
fun runBlock(block: Node.Statement.Block, context: BlockContext): Node {

block.nodes.forEachIndexed { index, node ->
// If is last node in the block
if (index + 1 == block.nodes.size || node is Node.Return) {
return runNode(node, context)
} else {
runNode(node, context)

val result = runNode(node, context)

if (index + 1 == block.nodes.size || result is Node.Return) {
return result
}
}

Expand Down Expand Up @@ -108,13 +117,13 @@ class CrescentVM(val files: List<Node.File>, val mainFile: Node.File) {
}

is Node.Return -> {
return runNode(node.expression, context)
return node
}

// TODO: Account for operator overloading
is Node.GetCall -> {

val arrayNode = (context.parameters[node.identifier]
val arrayNode = (context.parameters[node.identifier]?.instance?.value
?: context.variables.getValue(node.identifier).instance.value) as Node.Array

return arrayNode.values[(runNode(node.arguments[0], context) as Primitive.Number).toI32().data]
Expand Down Expand Up @@ -142,13 +151,18 @@ class CrescentVM(val files: List<Node.File>, val mainFile: Node.File) {
}

is Node.Statement.If -> {
return if ((runNode(node.predicate, context) as Primitive.Boolean).data) {
runBlock(node.block, context)
} else {
node.elseBlock?.let {
runBlock(it, context)

val result =
if ((runNode(node.predicate, context) as Primitive.Boolean).data) {
runBlock(node.block, context)
}
} ?: Type.unit
else {
node.elseBlock?.let {
runBlock(it, context)
}
} ?: Type.unit

return result
}

is Node.Statement.While -> {
Expand Down Expand Up @@ -264,7 +278,8 @@ class CrescentVM(val files: List<Node.File>, val mainFile: Node.File) {

is Node.Variable.Basic,
is Node.Variable.Local -> {
runVariable(node as Node.Variable, context)
val variable = runVariable(node as Node.Variable, context)
context.variables[variable.name] = variable
}

else -> error("Unexpected node: $node")
Expand Down Expand Up @@ -412,8 +427,7 @@ class CrescentVM(val files: List<Node.File>, val mainFile: Node.File) {
is Node.GetCall -> {
checkEquals(1, pop2.arguments.size)
val index = (pop2.arguments.first() as Primitive.Number).toI32().data
(context.variables.getValue(pop2.identifier).instance.value as Node.Array).values[index] =
value
(context.variables.getValue(pop2.identifier).instance.value as Node.Array).values[index] = value
}

is Node.Identifier -> {
Expand Down Expand Up @@ -622,14 +636,7 @@ class CrescentVM(val files: List<Node.File>, val mainFile: Node.File) {

"sqrt" -> {
checkEquals(1, node.arguments.size)
return Primitive.Number.F64(
sqrt(
(runNode(
node.arguments[0],
context
) as Primitive.Number).toF64().data
)
)
return Primitive.Number.F64(sqrt((runNode(node.arguments[0], context) as Primitive.Number).toF64().data))
}

"sin" -> {
Expand All @@ -639,14 +646,7 @@ class CrescentVM(val files: List<Node.File>, val mainFile: Node.File) {

"round" -> {
checkEquals(1, node.arguments.size)
return Primitive.Number.F64(
round(
(runNode(
node.arguments[0],
context
) as Primitive.Number).toF64().data
)
)
return Primitive.Number.F64(round((runNode(node.arguments[0], context) as Primitive.Number).toF64().data))
}

"print" -> {
Expand Down Expand Up @@ -726,7 +726,7 @@ class CrescentVM(val files: List<Node.File>, val mainFile: Node.File) {
check(parameter is Node.Parameter.Basic) {
"Crescent doesn't support parameters with default values yet."
}
check(parameter.type == Type.any || parameter.type == findType(argumentValues[index])) {
checkIsSameType(parameter.type, argumentValues[index]) {
"Parameter ${parameter.name} had an argument of type ${findType(argumentValues[index])}, expected ${parameter.type}"
}
}
Expand Down Expand Up @@ -791,7 +791,8 @@ class CrescentVM(val files: List<Node.File>, val mainFile: Node.File) {
Type.Basic(value.identifier)
}
else {
error("Unexpected value: ${value::class}")
// TODO: Resolve the function return type
Type.any
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,30 @@ internal class CrescentLexerTests {

assertContentEquals(
listOf(
FUN, Key("main"), Parenthesis.OPEN, Key("args"), TYPE_PREFIX, SquareBracket.OPEN, Key("String"), SquareBracket.CLOSE, Parenthesis.CLOSE, Bracket.OPEN,

FUN, Key("test1"), Parenthesis.OPEN, Key("args"), TYPE_PREFIX, SquareBracket.OPEN, Key("String"), SquareBracket.CLOSE, Parenthesis.CLOSE, Bracket.OPEN,
IF, Parenthesis.OPEN, Key("args"), SquareBracket.OPEN, Data.Number(0.toByte()), SquareBracket.CLOSE, EQUALS_COMPARE, Data.String("true"), Parenthesis.CLOSE, Bracket.OPEN,
Key("println"), Parenthesis.OPEN, Data.String("Meow"), Parenthesis.CLOSE,
Bracket.CLOSE,
ELSE, Bracket.OPEN,
Key("println"), Parenthesis.OPEN, Data.String("Hiss"), Parenthesis.CLOSE,
Bracket.CLOSE,
Bracket.CLOSE,

FUN, Key("test2"), Parenthesis.OPEN, Key("args"), TYPE_PREFIX, SquareBracket.OPEN, Key("String"), SquareBracket.CLOSE, Parenthesis.CLOSE, RETURN, Key("String"), Bracket.OPEN,
IF, Parenthesis.OPEN, Key("args"), SquareBracket.OPEN, Data.Number(0.toByte()), SquareBracket.CLOSE, EQUALS_COMPARE, Data.String("true"), Parenthesis.CLOSE, Bracket.OPEN,
RETURN, Data.String("Meow"),
Bracket.CLOSE,
ELSE, Bracket.OPEN,
RETURN, Data.String("Hiss"),
Bracket.CLOSE,
Key("println"), Parenthesis.OPEN, Data.String("This shouldn't be printed"), Parenthesis.CLOSE,
Bracket.CLOSE,

FUN, Key("main"), Parenthesis.OPEN, Key("args"), TYPE_PREFIX, SquareBracket.OPEN, Key("String"), SquareBracket.CLOSE, Parenthesis.CLOSE, Bracket.OPEN,
Key("test1"), Parenthesis.OPEN, Key("args"), Parenthesis.CLOSE,
Key("println"), Parenthesis.OPEN, Key("test2"), Parenthesis.OPEN, Key("args"), Parenthesis.CLOSE, Parenthesis.CLOSE,
Bracket.CLOSE
),
tokens
)
Expand Down Expand Up @@ -205,14 +221,16 @@ internal class CrescentLexerTests {

OBJECT, Key("Constants"), Bracket.OPEN,
CONST, Key("thing2"), ASSIGN, Data.String("Meow"),
FUN, Key("printThing"), Parenthesis.OPEN, Parenthesis.CLOSE, Bracket.OPEN,
FUN, Key("printThings"), Parenthesis.OPEN, Parenthesis.CLOSE, Bracket.OPEN,
Key("println"), Parenthesis.OPEN, Key("thing1"), Parenthesis.CLOSE,
Key("println"), Parenthesis.OPEN, Key("thing2"), Parenthesis.CLOSE,
Bracket.CLOSE,
Bracket.CLOSE,

FUN, Key("main"), Bracket.OPEN,
Key("Constants"), DOT, Key("printThings"), Parenthesis.OPEN, Parenthesis.CLOSE,
Key("println"), Parenthesis.OPEN, Key("thing1"), Parenthesis.CLOSE,
Key("println"), Parenthesis.OPEN, Key("Constants"), DOT, Key("thing2"), Parenthesis.CLOSE,
Bracket.CLOSE
),
tokens
Expand All @@ -228,7 +246,7 @@ internal class CrescentLexerTests {
listOf(

STRUCT, Key("Example"), Parenthesis.OPEN,
VAL, Key("aNumber"), TYPE_PREFIX, Key("Int"), Data.Comment("New lines makes commas redundant"),
VAL, Key("aNumber"), TYPE_PREFIX, Key("I32"), Data.Comment("New lines makes commas redundant"),
VAL, Key("aValue1"), Key("aValue2"), ASSIGN, Data.String(""), Data.Comment("Multi declaration of same type, can all be set to one or multiple default values"),
Parenthesis.CLOSE,

Expand All @@ -243,10 +261,10 @@ internal class CrescentLexerTests {

Data.Comment("Can't use self in static syntax"),
IMPL, Modifier.STATIC, Key("Example"), Bracket.OPEN,
FUN, Key("add"), Parenthesis.OPEN, Key("value1"), Key("value2"), TYPE_PREFIX, Key("Int"), Parenthesis.CLOSE, RETURN, Key("Int"), Bracket.OPEN,
FUN, Key("add"), Parenthesis.OPEN, Key("value1"), Key("value2"), TYPE_PREFIX, Key("I32"), Parenthesis.CLOSE, RETURN, Key("I32"), Bracket.OPEN,
RETURN, Key("value1"), ADD, Key("value2"),
Bracket.CLOSE,
FUN, Key("sub"), Parenthesis.OPEN, Key("value1"), Key("value2"), TYPE_PREFIX, Key("Int"), Parenthesis.CLOSE, RETURN, Key("Int"), Bracket.OPEN,
FUN, Key("sub"), Parenthesis.OPEN, Key("value1"), Key("value2"), TYPE_PREFIX, Key("I32"), Parenthesis.CLOSE, RETURN, Key("I32"), Bracket.OPEN,
RETURN, Key("value1"), SUB, Key("value2"),
Bracket.CLOSE,
Bracket.CLOSE,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,31 +55,60 @@ internal class CrescentParserTests {
fun ifStatement() {

val tokens = CrescentLexer.invoke(TestCode.ifStatement)
val parsed = CrescentParser.invoke(Path.of("example.crescent"), tokens)

val mainFunction = assertNotNull(
CrescentParser.invoke(Path.of("example.crescent"), tokens).mainFunction,
parsed.mainFunction,
"No main function found"
)

assertContentEquals(
listOf(Parameter.Basic("args", Type.Array(Type.Basic("String")))),
mainFunction.params
)

assertContentEquals(
listOf(
Statement.If(
Expression(listOf(
predicate = Expression(listOf(
GetCall("args", listOf(I8(0))), String("true"), EQUALS_COMPARE
)),
Statement.Block(listOf(
block = Statement.Block(listOf(
IdentifierCall("println", listOf(String("Meow")))
)),
Statement.Block(listOf(
elseBlock = Statement.Block(listOf(
IdentifierCall("println", listOf(String("Hiss")))
)),
),
),
parsed.functions["test1"]!!.innerCode.nodes,
)

assertContentEquals(
listOf(
Statement.If(
predicate = Expression(listOf(
GetCall("args", listOf(I8(0))), String("true"), EQUALS_COMPARE
)),
block = Statement.Block(listOf(
Return(String("Meow"))
)),
elseBlock = Statement.Block(listOf(
Return(String("Hiss"))
)),
),
IdentifierCall("println", listOf(String("This shouldn't be printed")))
),
parsed.functions["test2"]!!.innerCode.nodes,
)


assertContentEquals(
listOf(Parameter.Basic("args", Type.Array(Type.Basic("String")))),
mainFunction.params
)

assertContentEquals(
listOf(
IdentifierCall("test1", listOf(Identifier("args"))),
IdentifierCall("println", listOf(IdentifierCall("test2", listOf(Identifier("args")))))
),
mainFunction.innerCode.nodes,
)
}
Expand Down Expand Up @@ -343,11 +372,13 @@ internal class CrescentParserTests {
IdentifierCall("println", listOf(Identifier("thing1"))),
IdentifierCall("println", listOf(Identifier("thing2")))
),
constantsObject.functions["printThing"]!!.innerCode.nodes
constantsObject.functions["printThings"]!!.innerCode.nodes
)
assertContentEquals(
listOf(
DotChain(listOf(Identifier("Constants"), IdentifierCall("printThings")))
DotChain(listOf(Identifier("Constants"), IdentifierCall("printThings"))),
IdentifierCall("println", listOf(Identifier("thing1"))),
IdentifierCall("println", listOf(DotChain(listOf(Identifier("Constants"), Identifier("thing2"))))),
),
mainFunction.innerCode.nodes
)
Expand All @@ -363,7 +394,7 @@ internal class CrescentParserTests {
assertContentEquals(
listOf(
Struct("Example", listOf(
Variable.Basic("aNumber", Type.Basic("Int"), Expression(emptyList()), true, Visibility.PUBLIC),
Variable.Basic("aNumber", Type.Basic("I32"), Expression(emptyList()), true, Visibility.PUBLIC),
Variable.Basic("aValue1", Type.Implicit, String(""), true, Visibility.PUBLIC),
Variable.Basic("aValue2", Type.Implicit, String(""), true, Visibility.PUBLIC),
))
Expand Down Expand Up @@ -426,8 +457,8 @@ internal class CrescentParserTests {
name = "add",
modifiers = emptyList(),
visibility = Visibility.PUBLIC,
params = listOf(Parameter.Basic("value1", Type.Basic("Int")), Parameter.Basic("value2", Type.Basic("Int"))),
returnType = Type.Basic("Int"),
params = listOf(Parameter.Basic("value1", Type.Basic("I32")), Parameter.Basic("value2", Type.Basic("I32"))),
returnType = Type.Basic("I32"),
innerCode = Statement.Block(listOf(
Return(Expression(listOf(Identifier("value1"), Identifier("value2"), ADD)))
))
Expand All @@ -436,8 +467,8 @@ internal class CrescentParserTests {
name = "sub",
modifiers = emptyList(),
visibility = Visibility.PUBLIC,
params = listOf(Parameter.Basic("value1", Type.Basic("Int")), Parameter.Basic("value2", Type.Basic("Int"))),
returnType = Type.Basic("Int"),
params = listOf(Parameter.Basic("value1", Type.Basic("I32")), Parameter.Basic("value2", Type.Basic("I32"))),
returnType = Type.Basic("I32"),
innerCode = Statement.Block(listOf(
Return(Expression(listOf(Identifier("value1"), Identifier("value2"), SUB)))
))
Expand Down
Loading

0 comments on commit 347c746

Please sign in to comment.