Commit dd322384 authored by Arthur Bit-Monnot's avatar Arthur Bit-Monnot

[anml] Generalize expression parsing (wip)

parent 920e6fb1
package copla.lang.model.common
object operators {
type TypingResult = Either[String, Type]
sealed abstract class Associativity
object Associativity {
case object Left extends Associativity
case object Right extends Associativity
case object Non extends Associativity
}
private implicit class TypeOps(private val tpe: Type) extends AnyVal {
def isNum: Boolean = tpe.isSubtypeOf(Type.Numeric)
def isInt: Boolean = tpe.isSubtypeOf(Type.Integer)
def isFloat: Boolean = tpe.isSubtypeOf(Type.Float)
}
sealed trait Operator {
def op: String
def precedence: Int
}
sealed abstract class UnaryOperator(val op: String, val precedence: Int) extends Operator {
def tpe(lhs: Type): TypingResult
}
sealed abstract class BinaryOperator(val op: String,
val precedence: Int,
val associativity: Associativity)
extends Operator {
def tpe(lhs: Type, rhs: Type): TypingResult
}
sealed abstract class BooleanBinaryOperator(op: String,
precedence: Int,
associativity: Associativity)
extends BinaryOperator(op, precedence, associativity) {
override def tpe(lhs: Type, rhs: Type): TypingResult = (lhs, rhs) match {
case (Type.Boolean, Type.Boolean) => Right(Type.Boolean)
case _ => Left("On of the subexpression does not have the boolean type.")
}
}
sealed abstract class NumericBinaryOperator(op: String, precedence: Int)
extends BinaryOperator(op, precedence, Associativity.Left) {
override def tpe(lhs: Type, rhs: Type): TypingResult = {
lhs.lowestCommonAncestor(rhs) match {
case None if !lhs.isNum => Left(s"Left hand side is not a numeric type but: $lhs")
case None if !rhs.isNum => Left(s"Right hand side is not a numeric type but: $rhs")
case Some(Type.Numeric) =>
Left(s"Left and right hand side have incompatible numeric types : ($lhs, $rhs)")
case Some(tpe) if tpe.isInt => Right(Type.Integer)
case Some(tpe) if tpe.isFloat => Right(Type.Float)
case x => sys.error(s"Unhandled case: $x")
}
}
}
sealed abstract class NumericComparison(op: String, precedence: Int)
extends BinaryOperator(op, precedence, Associativity.Non) {
override def tpe(lhs: Type, rhs: Type): TypingResult = {
lhs.lowestCommonAncestor(rhs) match {
case None if !lhs.isNum => Left(s"Left hand side is not a numeric type but: $lhs")
case None if !rhs.isNum => Left(s"Right hand side is not a numeric type but: $rhs")
case Some(Type.Numeric) =>
Left(s"Left and right hand side have incompatible numeric types : ($lhs, $rhs)")
case Some(tpe) if tpe.isInt || tpe.isFloat => Right(Type.Boolean)
case x => sys.error(s"Unhandled case: $x")
}
}
}
sealed abstract class EqualityOperator(op: String, precedence: Int)
extends BinaryOperator(op, precedence, Associativity.Non) {
override def tpe(lhs: Type, rhs: Type): TypingResult = {
if(!lhs.overlaps(rhs))
Left(s"Comparing unrelated types: $lhs $rhs")
else
Right(Type.Boolean)
}
}
case object Implies extends BooleanBinaryOperator("implies", 1, Associativity.Right)
case object Xor extends BooleanBinaryOperator("xor", 3, Associativity.Left)
case object Or extends BooleanBinaryOperator("or", 4, Associativity.Left)
case object And extends BooleanBinaryOperator("and", 5, Associativity.Left)
case object Eq extends EqualityOperator("==", 7)
case object Neq extends EqualityOperator("!=", 7)
case object LT extends NumericComparison("<", 7)
case object GT extends NumericComparison(">", 7)
case object LEQ extends NumericComparison("<=", 7)
case object GEQ extends NumericComparison(">=", 7)
case object Add extends NumericBinaryOperator("+", 13)
case object Sub extends NumericBinaryOperator("-", 13)
case object Mul extends NumericBinaryOperator("*", 14)
case object Div extends NumericBinaryOperator("/", 14)
case object Minus extends UnaryOperator("-", 15) {
override def tpe(lhs: Type): TypingResult =
if(lhs.isInt || lhs.isFloat) Right(lhs)
else Left(s"Non numeric type $lhs")
}
val all: Set[Operator] =
Set(Implies, Xor, Or, And, Eq, Neq, LT, GT, LEQ, GEQ, Add, Sub, Mul, Div, Minus)
val layeredOps = computeLayeredOps(all)
private def computeLayeredOps(ops: Set[Operator]): Seq[OperatorGroup] = {
all
.groupBy(_.precedence)
.values
.map(opSet => {
opSet.head match {
case bin: BinaryOperator =>
BinOperatorGroup(opSet.map(_.asInstanceOf[BinaryOperator]),
bin.precedence,
bin.associativity)
case uni: UnaryOperator =>
UniOperatorGroup(opSet.map(_.asInstanceOf[UnaryOperator]), uni.precedence)
}
})
.toSeq
.sortBy(_.precedence)
.reverse
}
sealed trait OperatorGroup {
def precedence: Int
}
case class BinOperatorGroup(ops: Set[BinaryOperator],
precedence: Int,
associativity: Associativity)
extends OperatorGroup {
require(ops.forall(op => op.precedence == precedence && op.associativity == associativity))
}
case class UniOperatorGroup(ops: Set[UnaryOperator], precedence: Int) extends OperatorGroup {
require(ops.forall(op => op.precedence == precedence))
}
}
......@@ -2,6 +2,8 @@ package copla.lang.model
import copla.lang.model
import scala.annotation.tailrec
package object common {
case class Id(scope: Scope, name: String) {
......@@ -38,14 +40,29 @@ package object common {
def isSubtypeOf(typ: Type): Boolean =
this == typ || parent.exists(t => t == typ || t.isSubtypeOf(typ))
def overlaps(typ: Type): Boolean = this.isSubtypeOf(typ) || typ.isSubtypeOf(this)
def overlaps(typ: Type): Boolean =
this.isSubtypeOf(typ) || typ.isSubtypeOf(this)
def lowestCommonAncestor(typ: Type): Option[Type] =
if(this.isSubtypeOf(typ))
Some(typ)
else if(typ.isSubtypeOf(this))
Some(this)
else
parent match {
case Some(father) => father.lowestCommonAncestor(typ)
case None => None
}
def asScope: Scope = id.scope + id.name
override def toString: String = id.toString
}
object Type {
val Integers = Type(Id(RootScope, "integer"), None)
val Numeric = Type(Id(RootScope, "__numeric__"), None)
val Integer = Type(Id(RootScope, "integer"), Some(Numeric))
val Float = Type(Id(RootScope, "float"), Some(Numeric))
val Boolean = Type(Id(RootScope, "boolean"), None)
}
sealed trait Term {
......@@ -60,7 +77,7 @@ package object common {
sealed trait Var extends Term
case class IntLiteral(value: Int) extends Cst {
override def typ: Type = Type.Integers
override def typ: Type = Type.Integer
override def id: Id = Id(RootScope + "_integers_", value.toString)
override def toString: String = value.toString
}
......
......@@ -146,7 +146,7 @@ package object core {
def apply(lit: Int): IntExpr = IntTerm(IntLiteral(lit))
}
final case class IntTerm(e: Term) extends IntExpr {
require(e.typ.isSubtypeOf(Type.Integers))
require(e.typ.isSubtypeOf(Type.Integer))
override def toString: String = e.toString
}
......
package copla.lang.model
import copla.lang.model.common._
import copla.lang.model.common.operators.{BinaryOperator, UnaryOperator}
package object full {
......@@ -61,6 +62,25 @@ package object full {
sealed trait StaticExpr extends Expr
sealed trait TimedExpr extends Expr
sealed trait ExprTree extends StaticExpr
case class BinaryExprTree(op: BinaryOperator, lhs: StaticExpr, rhs: StaticExpr) extends ExprTree {
override val typ: Type = op.tpe(lhs.typ, rhs.typ) match {
case Right(tpe) => tpe
case Left(err) =>
sys.error(err)
}
override def toString: String = s"(${op.op} $lhs $rhs)"
}
case class UnaryExprTree(op: UnaryOperator, lhs: StaticExpr) extends ExprTree {
override val typ: Type = op.tpe(lhs.typ) match {
case Right(tpe) => tpe
case Left(err) => sys.error(err)
}
override def toString: String = s"(${op.op} $lhs)"
}
sealed trait CommonTerm extends StaticExpr
object CommonTerm {
def apply(v: Term): CommonTerm = v match {
......@@ -71,9 +91,11 @@ package object full {
case class Variable(v: Var) extends CommonTerm {
override def typ: Type = v.typ
override def toString: String = v.toString
}
case class ConstantExpr(term: Cst) extends CommonTerm {
override def typ: Type = term.typ
override def toString: String = term.toString
}
sealed trait IntExpr
......@@ -200,9 +222,16 @@ package object full {
}
trait StaticAssertion extends Statement
case class BooleanAssertion(expr: StaticExpr) extends StaticAssertion {
require(expr.typ.isSubtypeOf(Type.Boolean))
override def toString: String = expr.toString
}
@deprecated
case class StaticEqualAssertion(left: StaticExpr, right: StaticExpr) extends StaticAssertion {
override def toString: String = s"$left == $right"
}
@deprecated
case class StaticDifferentAssertion(left: StaticExpr, right: StaticExpr) extends StaticAssertion {
override def toString: String = s"$left != $right"
}
......
......@@ -9,6 +9,14 @@ import ParserApi.whiteApi._
import ParserApi.extendedApi._
import fastparse.core.Parsed.Failure
import copla.lang.model.common._
import copla.lang.model.common.operators.{
Associativity,
BinOperatorGroup,
BinaryOperator,
OperatorGroup,
UnaryOperator,
UniOperatorGroup
}
import copla.lang.model.full._
import scala.annotation.tailrec
......@@ -222,12 +230,13 @@ abstract class AnmlParser(val initialContext: Ctx) {
}
}
val staticExpr: Parser[StaticExpr] = {
val staticTerm: Parser[StaticExpr] = {
val partiallyAppliedConstant = partiallyAppliedFunction
.namedFilter(_._1.isInstanceOf[ConstantTemplate], "is-constant")
.map(tup => (tup._1.asInstanceOf[ConstantTemplate], tup._2))
variable.map(CommonTerm(_)) |
int.map(i => CommonTerm(IntLiteral(i))) |
variable.map(CommonTerm(_)) |
(constantFunc ~/ Pass).flatMap(f =>
f.params.map(param => param.typ) match {
case Seq() => (("(" ~/ ")") | Pass) ~ PassWith(Constant(f, Seq()))
......@@ -243,8 +252,79 @@ abstract class AnmlParser(val initialContext: Ctx) {
"(" ~/ varList(paramTypes.tail, ",")
.map(args => Constant(f, CommonTerm(firstArg) +: args)) ~ ")" ~/ Pass
}
} |
int.map(i => CommonTerm(IntLiteral(i)))
}
}
val staticExpr: P[StaticExpr] = Tmp.expr
object Tmp {
type E = StaticExpr
type PE = Parser[StaticExpr]
def term: PE = P(staticTerm ~/ Pass)
def expr: PE = P(top)
val bottom: PE = P(("(" ~/ expr ~/ ")") | term)
val top: PE =
operators.layeredOps.foldLeft(bottom) {
case (inner, opGroup) => groupParser(opGroup, inner)
}
def binGroupParser(gpe: BinOperatorGroup, inner: PE): PE = {
// parser for a single operator in the group
val operator: P[BinaryOperator] =
StringIn(gpe.ops.map(_.op).toSeq: _*).!
.optGet(str => gpe.ops.find(_.op == str))
.opaque(gpe.ops.map(a => "\""+a.op+"\"").mkString("(", "|",")"))
gpe match {
case BinOperatorGroup(ops, _, Associativity.Left) =>
(inner ~/ (operator ~/ inner).rep).optGet({
case (head, tail) =>
tail.foldLeft(Option(head)) {
case (acc, (op, rhs)) => acc.flatMap(x => asBinary(op, x, rhs))
}
}, "well-typed")
case BinOperatorGroup(ops, _, Associativity.Right) =>
(inner ~/ (operator ~/ inner).rep).optGet({
case (head, tail) =>
def makeRightAssociative[A,B](e1: A, nexts: List[(B, A)]): (List[(A,B)], A) = nexts match {
case Nil => (Nil, e1)
case (b, e2) :: rest =>
val (prevs, last) = makeRightAssociative(e2, rest)
((e1, b) :: prevs, last)
}
val (prevs: List[(StaticExpr, BinaryOperator)], last: StaticExpr) = makeRightAssociative(head, tail.toList)
prevs.foldRight(Option(last)) { case ((lhs, op), rhs) => rhs.flatMap(asBinary(op, lhs, _)) }
}, "well-typed")
case BinOperatorGroup(ops, _, Associativity.Non) =>
(inner ~/ (operator ~/ inner).?).optGet({
case (lhs, None) => Some(lhs)
case (lhs, Some((op, rhs))) => asBinary(op, lhs, rhs)
}, "well-typed")
}
}
def unaryGroupParser(gpe: UniOperatorGroup, inner: PE): PE = {
val operator: P[UnaryOperator] =
StringIn(gpe.ops.map(_.op).toSeq: _*).!
.optGet(str => gpe.ops.find(_.op == str))
.opaque(gpe.ops.map(a => "\""+a.op+"\"").mkString("(", "|",")"))
(operator.? ~ inner).sideEffect(println).optGet({
case (None, e) => Some(e)
case (Some(op), e) => Try(full.UnaryExprTree(op, e)).toOption
}, "well-typed")
}
def groupParser(gpe: OperatorGroup, inner: PE): PE = gpe match {
case x: BinOperatorGroup => binGroupParser(x, inner)
case x: UniOperatorGroup => unaryGroupParser(x, inner)
}
def asBinary(op: BinaryOperator, lhs: StaticExpr, rhs: StaticExpr): Option[StaticExpr] = {
op.tpe(lhs.typ, rhs.typ) match {
case Right(_) => Some(full.BinaryExprTree(op, lhs, rhs))
case Left(_) => None
}
}
}
object IntOperators {
......@@ -252,7 +332,7 @@ abstract class AnmlParser(val initialContext: Ctx) {
def primary: Parser[IntExpr] = P {
P("(").flatMap(_ => additiveExpr ~ ")") |
staticExpr.namedFilter(_.typ == Type.Integers, "of-type-integer").map(full.GenIntExpr) |
staticExpr.namedFilter(_.typ == Type.Integer, "of-type-integer").map(full.GenIntExpr) |
int.map(i => GenIntExpr(ConstantExpr(IntLiteral(i)))) |
"-" ~/ primary
}
......@@ -345,18 +425,15 @@ abstract class AnmlParser(val initialContext: Ctx) {
val staticAssertion: Parser[StaticAssertion] = {
var leftExpr: StaticExpr = null
(staticExpr.sideEffect(leftExpr = _) ~/
(("==" | "!=" | ":=").! ~/
staticExpr.namedFilter(_.typ.overlaps(leftExpr.typ), "has-compatible-type")).? ~
(":=".! ~/ staticExpr.namedFilter(_.typ.overlaps(leftExpr.typ), "has-compatible-type")).? ~
";")
.namedFilter({
case (_, Some(_)) => true
case (expr, None) => expr.typ.id.name == "boolean"
case (expr, None) => expr.typ.isSubtypeOf(Type.Boolean)
}, "boolean-if-no-right-side")
.map {
case (left, Some(("==", right))) => StaticEqualAssertion(left, right)
case (left, Some(("!=", right))) => StaticDifferentAssertion(left, right)
case (left, Some((":=", right))) => StaticAssignmentAssertion(left, right)
case (expr, None) => StaticEqualAssertion(expr, CommonTerm(ctx.findVariable("true").get))
case (expr, None) => BooleanAssertion(expr)
case _ => sys.error("Something is wrong with this parser.")
}
}
......@@ -452,10 +529,10 @@ class AnmlModuleParser(val initialModel: Model) extends AnmlParser(initialModel)
instancesDeclaration |
functionDeclaration.map(Seq(_)) |
timepointDeclaration.map(Seq(_)) |
temporalConstraint |
temporallyQualifiedAssertion |
staticAssertion.map(Seq(_)) |
action.map(Seq(_))
// temporalConstraint |
// temporallyQualifiedAssertion |
staticAssertion.map(Seq(_)) //|
// action.map(Seq(_))
def currentModel: Model = ctx match {
case m: Model => m
......@@ -554,20 +631,17 @@ class AnmlTypeParser(val initialModel: Model) extends AnmlParser(initialModel) {
object Parser {
private val anmlHeader =
"""|type boolean;
|instance boolean true, false;
|type integer;
|timepoint start;
|timepoint end;
""".stripMargin
/** ANML model with default definitions already added */
val baseAnmlModel: Model =
parse(anmlHeader, Some(new Model())) match {
case ParseSuccess(model) => model
case err: ParseFailure => sys.error("Could not parse the ANML headed:\n" + err.format)
}
(Model() ++ Seq(
TypeDeclaration(Type.Boolean),
TypeDeclaration(Type.Numeric),
TypeDeclaration(Type.Integer),
InstanceDeclaration(Instance(Id(RootScope, "true"), Type.Boolean)),
InstanceDeclaration(Instance(Id(RootScope, "false"), Type.Boolean)),
TimepointDeclaration(Timepoint(Id(RootScope, "start"))),
TimepointDeclaration(Timepoint(Id(RootScope, "end"))),
)).getOrElse(sys.error("Could not instantiate base model"))
/** Parses an ANML string. If the previous model parameter is Some(m), then the result
* of parsing will be appended to m.
......
......@@ -12,6 +12,8 @@ class AnmlParsingTest extends FunSuite {
println("AS:\n" + module + "\n\n")
case x: ParseFailure =>
fail(s"Could not parse anml string: $anml\n\n${x.format}")
case UnidentifiedError(err, _) => err.printStackTrace()
fail(s"Exception raised while parsing")
}
}
}
......@@ -27,12 +29,20 @@ class AnmlParsingTest extends FunSuite {
}
}
val tmp = "type A with { fluent boolean x; }; type B with { fluent boolean x; };"
val tmp = "constant integer i; i == 2 implies (-i + 16) * 2 == -4 implies true;"
test("debug: temporary") {
/** Dummy text to facilitate testing. */
println(tmp)
println(Parser.parse(tmp))
Parser.parse(tmp) match {
case ParseSuccess(module) =>
println("PARSED:\n")
println("AS:\n" + module + "\n\n")
case x: ParseFailure =>
fail(s"Could not parse anml string:\n${x.format}")
case UnidentifiedError(err, _) => err.printStackTrace()
fail(s"Exception raised while parsing")
}
}
}
......@@ -57,7 +57,7 @@ case class ProblemContext(intTag: BoxedInt[Literal],
ie match {
case IntTerm(IntLiteral(d)) => Cst(d)
case IntTerm(v: Var) =>
assert(v.typ == Type.Integers)
assert(v.typ == Type.Integer)
val variable = encode(v)
variable.typ match {
case tpe: BoxedInt[Literal] => variable.unboxed(tpe)
......@@ -167,7 +167,7 @@ case class ProblemContext(intTag: BoxedInt[Literal],
object ProblemContext {
def extract(m: Seq[InModuleBlock]): ProblemContext = {
val objectTypes = m.collect { case TypeDeclaration(t) if t != Type.Integers => t }
val objectTypes = m.collect { case TypeDeclaration(t) if t != Type.Integer => t }
val objectSubtypes = mutable.LinkedHashMap[Type, mutable.Set[Type]]()
val instances = m
.collect { case InstanceDeclaration(i) => i }
......@@ -203,7 +203,7 @@ object ProblemContext {
assert(toIndex.size == fromIndex.size)
def tagOf(t: Type): TagIsoInt[ObjLit] = {
assert(t != Type.Integers)
assert(t != Type.Integer)
def instancesOf(t: Type): Seq[ObjLit] =
instances.getOrElse(t, Seq()) ++ objectSubtypes(t).flatMap(instancesOf)
......@@ -276,7 +276,7 @@ object ProblemContext {
val memo = mutable.Map[Type, TagIsoInt[ObjLit]]()
val specializedTag = (t: Type) =>
if(t == Type.Integers)
if(t == Type.Integer)
intTag
else
memo.getOrElseUpdate(t, tagOf(t))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment