ShintenScript/ASTNodeComparison.cs

135 lines
5.0 KiB
C#

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace ShintenScript
{
    /// <summary>
    /// AST node for a comparison expression, possibly chained: <c>n0 op0 n1 op1 n2 ...</c>.
    /// The chain is true iff every adjacent link <c>n[i] op[i] n[i+1]</c> holds; evaluation
    /// stops at the first failing link, so operands after it are not evaluated.
    /// </summary>
    /// <remarks>
    /// Boolean operands support only <c>EQ</c>/<c>NE</c>; real operands additionally support
    /// <c>LT</c>, <c>GT</c>, <c>LE</c> and <c>GE</c> (enforced by <c>Type()</c>).
    /// </remarks>
    public class ASTNodeComparison : IASTNode
    {
        /// <summary>Operands in source order; expected to contain <c>comparisons.Count + 1</c> entries.</summary>
        public readonly IList<IASTNode> nodes;

        /// <summary>Operators; <c>comparisons[i]</c> sits between <c>nodes[i]</c> and <c>nodes[i + 1]</c>.</summary>
        public readonly IList<Token.Type> comparisons;

        public ASTNodeComparison(IList<IASTNode> nodes, IList<Token.Type> comparisons)
        {
            this.nodes = nodes;
            this.comparisons = comparisons;
        }

        /// <summary>Returns true when every operand reports itself as constant.</summary>
        public bool Const()
        {
            return nodes.All(node => node.Const());
        }

        /// <summary>
        /// Compiles the comparison into a <c>Func&lt;ExecutionContext, bool&gt;</c>, returned as <c>object</c>.
        /// </summary>
        /// <remarks>
        /// The operand type is taken from <c>nodes[0]</c>.
        /// NOTE(review): this assumes <c>Type()</c> has already validated the node; confirm callers
        /// always type-check before compiling.
        /// </remarks>
        /// <exception cref="InvalidOperationException">
        /// The operand and operator counts disagree, the operand type is neither boolean nor real,
        /// or an operator cannot be applied to the operand type.
        /// </exception>
        public object CreateFunction()
        {
            if (nodes.Count != comparisons.Count + 1)
                throw new InvalidOperationException(
                    $"Comparison has {nodes.Count} operands but {comparisons.Count} operators");

            SSType argType = nodes[0].Type();
            if (argType == SSType.boolean)
                return CreateBooleanFunction();
            if (argType == SSType.real)
                return CreateRealFunction();
            throw new InvalidOperationException($"Cannot compile a comparison of type {argType}");
        }

        /// <summary>Compiles a comparison whose operands are booleans (only EQ/NE apply).</summary>
        private Func<ExecutionContext, bool> CreateBooleanFunction()
        {
            Func<ExecutionContext, bool>[] args = CompileOperands<bool>();

            // Fast path for the common two-operand case: no loop and no per-link lookup.
            if (comparisons.Count == 1)
            {
                Func<ExecutionContext, bool> left = args[0], right = args[1];
                switch (comparisons[0])
                {
                    case Token.Type.EQ: return ctx => left(ctx) == right(ctx);
                    case Token.Type.NE: return ctx => left(ctx) != right(ctx);
                    default: throw UnsupportedOperator(comparisons[0], SSType.boolean);
                }
            }

            // Resolve every link once, now, instead of re-reading the mutable comparisons list on
            // each evaluation. Ordering operators are rejected here rather than silently acting as NE.
            bool[] wantEqual = new bool[comparisons.Count];
            for (int i = 0; i < wantEqual.Length; i++)
            {
                Token.Type op = comparisons[i];
                if (op != Token.Type.EQ && op != Token.Type.NE)
                    throw UnsupportedOperator(op, SSType.boolean);
                wantEqual[i] = op == Token.Type.EQ;
            }

            return ctx =>
            {
                bool against = args[0](ctx);
                for (int i = 1; i < args.Length; i++)
                {
                    bool vs = args[i](ctx);
                    if ((against == vs) != wantEqual[i - 1])
                        return false;
                    against = vs;
                }
                return true;
            };
        }

        /// <summary>Compiles a comparison whose operands are reals (all six operators apply).</summary>
        private Func<ExecutionContext, bool> CreateRealFunction()
        {
            Func<ExecutionContext, float>[] args = CompileOperands<float>();

            // Fast path for the common two-operand case: operators applied inline, no extra delegate.
            if (comparisons.Count == 1)
            {
                Func<ExecutionContext, float> left = args[0], right = args[1];
                switch (comparisons[0])
                {
                    case Token.Type.EQ: return ctx => left(ctx) == right(ctx);
                    case Token.Type.NE: return ctx => left(ctx) != right(ctx);
                    case Token.Type.LT: return ctx => left(ctx) < right(ctx);
                    case Token.Type.GT: return ctx => left(ctx) > right(ctx);
                    case Token.Type.LE: return ctx => left(ctx) <= right(ctx);
                    case Token.Type.GE: return ctx => left(ctx) >= right(ctx);
                    default: throw UnsupportedOperator(comparisons[0], SSType.real);
                }
            }

            // One predicate per link, resolved (and validated) up front.
            Func<float, float, bool>[] links = comparisons.Select(op => RealLink(op)).ToArray();

            return ctx =>
            {
                float against = args[0](ctx);
                for (int i = 1; i < args.Length; i++)
                {
                    float vs = args[i](ctx);
                    if (!links[i - 1](against, vs))
                        return false;
                    against = vs;
                }
                return true;
            };
        }

        /// <summary>
        /// Returns the predicate for a single real-valued link <c>left op right</c>.
        /// Each operator is applied directly (never as a negated complement such as
        /// <c>!(right &lt;= left)</c>), so a NaN operand makes every ordered link false,
        /// matching the two-operand fast path.
        /// </summary>
        /// <exception cref="InvalidOperationException"><paramref name="op"/> is not a comparison operator.</exception>
        private static Func<float, float, bool> RealLink(Token.Type op)
        {
            switch (op)
            {
                case Token.Type.EQ: return (left, right) => left == right;
                case Token.Type.NE: return (left, right) => left != right;
                case Token.Type.LT: return (left, right) => left < right;
                case Token.Type.GT: return (left, right) => left > right;
                case Token.Type.LE: return (left, right) => left <= right;
                case Token.Type.GE: return (left, right) => left >= right;
                default: throw UnsupportedOperator(op, SSType.real);
            }
        }

        /// <summary>Compiles every operand and casts it to the evaluator delegate for <typeparamref name="T"/>.</summary>
        private Func<ExecutionContext, T>[] CompileOperands<T>()
        {
            return nodes.Select(node => (Func<ExecutionContext, T>)node.CreateFunction()).ToArray();
        }

        /// <summary>Builds the exception reported when an operator does not apply to an operand type.</summary>
        private static InvalidOperationException UnsupportedOperator(Token.Type op, SSType type)
        {
            return new InvalidOperationException($"Operator {op} cannot be applied to operands of type {type}");
        }

        /// <summary>Type-checks the comparison; the result type is always <c>SSType.boolean</c>.</summary>
        /// <exception cref="TypeException">
        /// The first operand's type is <c>SSType.none</c>, any operand's type differs from the first,
        /// or an operator other than EQ/NE is used on a non-real type.
        /// </exception>
        public SSType Type()
        {
            // Any operator besides EQ/NE requires an ordered operand type.
            bool ord = !comparisons.All(cmp => cmp == Token.Type.EQ || cmp == Token.Type.NE);
            SSType argType = nodes[0].Type();
            if (argType == SSType.none)
                throw new TypeException("Cannot compare null");
            if (nodes.Skip(1).Any(node => node.Type() != argType))
                throw new TypeException("Cannot compare different types");
            if (ord && argType != SSType.real)
                throw new TypeException($"Type {argType} cannot be ordered");
            return SSType.boolean;
        }
    }
}