Commit 7f0d756b authored by Arthur Bit-Monnot's avatar Arthur Bit-Monnot

[anml] Refactor type system.

parent e9b19182
......@@ -12,9 +12,9 @@ object operators {
}
private implicit class TypeOps(private val tpe: Type) extends AnyVal {
def isNum: Boolean = tpe.isSubtypeOf(Type.Numeric)
def isInt: Boolean = tpe.isSubtypeOf(Type.Integer)
def isFloat: Boolean = tpe.isSubtypeOf(Type.Float)
def isNum: Boolean = tpe.isSubtypeOf(Type.Reals)
def isInt: Boolean = tpe.isSubtypeOf(Type.Integers)
def isFloat: Boolean = isNum && !isInt
}
sealed trait Operator {
......@@ -46,12 +46,10 @@ object operators {
extends BinaryOperator(op, precedence, Associativity.Left) {
override def tpe(lhs: Type, rhs: Type): TypingResult = {
lhs.lowestCommonAncestor(rhs) match {
case None if !lhs.isNum => Left(s"Left hand side is not a numeric type but: $lhs")
case None if !rhs.isNum => Left(s"Right hand side is not a numeric type but: $rhs")
case Some(Type.Numeric) =>
Left(s"Left and right hand side have incompatible numeric types : ($lhs, $rhs)")
case Some(tpe) if tpe.isInt => Right(Type.Integer)
case Some(tpe) if tpe.isFloat => Right(Type.Float)
case None if !lhs.isNum => Left(s"Left hand side is not a numeric type but: $lhs")
case None if !rhs.isNum => Left(s"Right hand side is not a numeric type but: $rhs")
case Some(tpe) if tpe.isInt => Right(Type.Integers)
case Some(tpe) if tpe.isFloat => Right(Type.Reals)
case x => sys.error(s"Unhandled case: $x")
}
......@@ -62,12 +60,10 @@ object operators {
extends BinaryOperator(op, precedence, Associativity.Non) {
override def tpe(lhs: Type, rhs: Type): TypingResult = {
lhs.lowestCommonAncestor(rhs) match {
case None if !lhs.isNum => Left(s"Left hand side is not a numeric type but: $lhs")
case None if !rhs.isNum => Left(s"Right hand side is not a numeric type but: $rhs")
case Some(Type.Numeric) =>
Left(s"Left and right hand side have incompatible numeric types : ($lhs, $rhs)")
case Some(tpe) if tpe.isInt || tpe.isFloat => Right(Type.Boolean)
case x => sys.error(s"Unhandled case: $x")
case None if !lhs.isNum => Left(s"Left hand side is not a numeric type but: $lhs")
case None if !rhs.isNum => Left(s"Right hand side is not a numeric type but: $rhs")
case Some(tpe) if tpe.isNum => Right(Type.Boolean)
case x => sys.error(s"Unhandled case: $x")
}
}
}
......
......@@ -2,6 +2,8 @@ package copla.lang.model
import copla.lang.model
import copla.lang.model.common.operators.{BinaryOperator, UnaryOperator}
import spire.math.Real
import spire.implicits._
package object common {
......@@ -35,7 +37,10 @@ package object common {
def toScopedString(name: String): String = s"$this.$name"
}
final case class Type(id: Id, parent: Option[Type]) {
sealed trait Type {
def id: Id
def parent: Option[Type]
def isSubtypeOf(typ: Type): Boolean =
this == typ || parent.exists(t => t == typ || t.isSubtypeOf(typ))
......@@ -58,26 +63,84 @@ package object common {
override def toString: String = id.toString
}
object Type {
val Numeric = Type(Id(RootScope, "__numeric__"), None)
val Integer = Type(Id(RootScope, "integer"), Some(Numeric))
val Time = Integer //Type(Id(RootScope, "time"), Some(Integer))
val Float = Type(Id(RootScope, "float"), Some(Numeric))
val Boolean = Type(Id(RootScope, "boolean"), None)
sealed trait ObjType extends Type {
def id: Id
def parent: Option[ObjType]
}
object ObjectTop extends ObjType {
override def id: Id = Id(RootScope, "__object__")
override def parent: Option[Nothing] = None
}
final case class ObjSubType(id: Id, father: ObjType) extends ObjType {
def parent: Some[ObjType] = Some(father)
}
sealed trait IRealType extends Type {
def min: Option[Real]
def max: Option[Real]
}
object Reals extends IRealType {
def id: Id = Id(RootScope, "float")
override def parent: Option[Nothing] = None
override def min: Option[Real] = None
override def max: Option[Real] = None
}
final case class RealSubType(id: Id,
father: IRealType,
min: Option[Real] = None,
max: Option[Real] = None)
extends IRealType {
require(min.forall(m => father.min.forall(_ <= m)))
require(max.forall(m => father.max.forall(m >= _)))
override def parent: Some[IRealType] = Some(father)
}
sealed trait IIntType extends IRealType {
def intMin: Option[BigInt]
def intMax: Option[BigInt]
override def min: Option[Real] = intMin.map(Real(_))
override def max: Option[Real] = intMax.map(Real(_))
}
object Integers extends IIntType {
override def id: Id = Id(RootScope, "integer")
override def parent: Some[Reals.type] = Some(Reals)
override def intMin: Option[BigInt] = None
override def intMax: Option[BigInt] = None
}
final case class IntSubType(id: Id,
father: IIntType,
intMin: Option[BigInt] = None,
intMax: Option[BigInt] = None)
extends IIntType {
require(intMin.forall(min => father.intMin.forall(_ <= min)))
require(intMax.forall(max => father.intMax.forall(max >= _)))
override def parent: Some[IIntType] = Some(father)
}
val Time = IntSubType(Id(RootScope, "time"), Integers)
val Boolean = ObjSubType(Id(RootScope, "boolean"), ObjectTop)
// val Numeric = Type(Id(RootScope, "__numeric__"), None)
// val Integer = Type(Id(RootScope, "integer"), Some(Numeric))
// val Time = Integer //Type(Id(RootScope, "time"), Some(Integer))
// val Float = Type(Id(RootScope, "float"), Some(Numeric))
// val Boolean = Type(Id(RootScope, "boolean"), None)
val True = Instance(Id(RootScope, "true"), Boolean)
val False = Instance(Id(RootScope, "false"), Boolean)
sealed trait Top
sealed trait Obj extends Top
sealed trait Real extends Top
sealed trait Int extends Real
type Time = Int
trait Tpe[+T]
case class TObj(id: Id, parent: Option[Type]) extends Tpe[Obj]
// case class TReal(id: Id, parent: Option[TReal]) extends Tpe[Real]
case class TInt(id: Id, parent: Option[TInt], min: Option[Int], max: Option[Int])
extends Tpe[Int]
val Start = Id(RootScope, "start")
val End = Id(RootScope, "end")
}
sealed trait Expr {
def typ: Type
......@@ -93,7 +156,7 @@ package object common {
sealed trait Var extends Term
case class IntLiteral(value: Int) extends Cst {
override def typ: Type = Type.Integer
override def typ: Type = Type.Integers
override def id: Id = Id(RootScope + "_integers_", value.toString)
override def toString: String = value.toString
}
......
......@@ -599,7 +599,10 @@ class AnmlTypeParser(val initialModel: Model) extends AnmlParser(initialModel) {
(word | int | CharIn("{}[]();=:<>-+.,!/*")).!.namedFilter(_ != "type", "non-type-token")
val typeDeclaration: Parser[TypeDeclaration] =
(typeKW ~/ freeIdent ~ ("<" ~ declaredType).? ~ (";" | withKW)).map {
case (name, parentOpt) => new TypeDeclaration(new Type(ctx.id(name), parentOpt))
case (name, None) => TypeDeclaration(Type.ObjSubType(ctx.id(name), Type.ObjectTop))
case (name, Some(t: Type.ObjType)) => TypeDeclaration(Type.ObjSubType(ctx.id(name), t))
case (name, Some(t: Type.IIntType)) => TypeDeclaration(Type.IntSubType(ctx.id(name), t))
case (name, Some(t: Type.IRealType)) => TypeDeclaration(Type.RealSubType(ctx.id(name), t))
}
private[this] def currentModel: Model = ctx match {
......@@ -623,13 +626,14 @@ object Parser {
/** ANML model with default definitions already added */
val baseAnmlModel: Model =
(Model() ++ Seq(
TypeDeclaration(Type.ObjectTop),
TypeDeclaration(Type.Boolean),
TypeDeclaration(Type.Numeric),
TypeDeclaration(Type.Integer),
TypeDeclaration(Type.Reals),
TypeDeclaration(Type.Integers),
InstanceDeclaration(Type.True),
InstanceDeclaration(Type.False),
TimepointDeclaration(Id(RootScope, "start")),
TimepointDeclaration(Id(RootScope, "end")),
TimepointDeclaration(Type.Start),
TimepointDeclaration(Type.End),
)).getOrElse(sys.error("Could not instantiate base model"))
/** Parses an ANML string. If the previous model parameter is Some(m), then the result
......
......@@ -71,6 +71,7 @@ lazy val anml = project
"com.lihaoyi" %% "fastparse" % "1.0.0",
"com.github.scopt" %% "scopt" % "3.7.0",
"com.chuusai" %% "shapeless" % "2.3.3",
"org.typelevel" %% "spire" % "0.14.1",
"org.scalatest" %% "scalatest" % "3.0.5" % "test"
))
......
......@@ -42,7 +42,8 @@ object anml extends Module {
def ivyDeps = Agg(
ivy"com.lihaoyi::fastparse:1.0.0",
ivy"com.github.scopt::scopt:3.7.0",
ivy"com.chuusai::shapeless:2.3.3"
ivy"com.chuusai::shapeless:2.3.3",
ivy"org.typelevel::spire:0.14.1"
)
object tests extends Tests {
......
......@@ -8,8 +8,6 @@ import dahu.model.types._
import scala.reflect.ClassTag
final case class ConstraintViolated(constraint: Tentative[Boolean])
/** Evaluation yields an Either[ConstraintViolated, T] */
sealed trait Tentative[T] {
def typ: Tag[T]
......
......@@ -99,9 +99,8 @@ object Interpreter {
}
}
}
case class ConstraintViolated(nodes: Seq[Any]) extends Result[Nothing] {
override def toString: String = "ConstraintViolated"
}
case class ConstraintViolated(nodes: Seq[Any]) extends Result[Nothing]
case class Res[T](v: T) extends Result[T]
case object Empty extends Result[Nothing]
......
......@@ -2,6 +2,7 @@ package dahu.model.problem
import cats.Functor
import cats.implicits._
import dahu.model.compiler.Algebras
import dahu.model.functions.Reversible
import dahu.utils._
import dahu.model.ir._
......@@ -44,11 +45,11 @@ object SatisfactionProblem {
object ElimReversible extends Optimizer {
override def optim(retrieve: ID => Total[ID],
record: Total[ID] => ID): Total[ID] => Total[ID] = {
case ComputationF(f: Reversible[_, _], Vec1(arg), _) =>
case orig @ ComputationF(f: Reversible[_, _], Vec1(arg), _) =>
retrieve(arg) match {
case ComputationF(f2: Reversible[_, _], Vec1(arg2), _) if f2 == f.reverse =>
retrieve(arg2)
case x => x
case _ => orig
}
case x => x
}
......@@ -172,10 +173,11 @@ object SatisfactionProblem {
ElimEmptyAndSingletonMonoid,
FlattenMonoids,
ElimDuplicationsIdempotentMonoids,
ElimTautologies,
ConstantFolding,
OrderArgs,
)
val optimizer: Optimizer = optimizers.foldLeft[Optimizer](NoOpOptimizer) {
val optimizer: Optimizer = (optimizers ++ optimizers).foldLeft[Optimizer](NoOpOptimizer) {
case (acc, next) => acc.andThen(next)
}
}
......@@ -267,51 +269,16 @@ object SatisfactionProblem {
val TRUE: ID = rec(CstF(Value(true), Tag.ofBoolean))
val FALSE: ID = rec(CstF(Value(false), Tag.ofBoolean))
def and(conjuncts: ID*): ID = {
if(conjuncts.contains(FALSE)) {
FALSE
} else {
val reduced = conjuncts.distinct.filter(_ != TRUE).sorted
if(reduced.isEmpty)
TRUE
else if(reduced.size == 1)
reduced.head
else
rec(ComputationF(bool.And, reduced, Tag.ofBoolean))
}
}
def and(conjuncts: Vec[ID]): ID = {
if(conjuncts.contains(FALSE)) {
FALSE
} else {
val reduced = conjuncts.distinct.filter(_ != TRUE).sorted
if(reduced.isEmpty)
TRUE
else if(reduced.size == 1)
reduced(0)
else
rec(ComputationF(bool.And, reduced, Tag.ofBoolean))
}
}
def or(disjuncts: ID*): ID = {
if(disjuncts.contains(TRUE)) {
TRUE
} else {
val reduced = disjuncts.distinct.filter(_ != FALSE).sorted
if(reduced.isEmpty)
FALSE
else if(reduced.size == 1)
reduced.head
else
rec(ComputationF(bool.Or, reduced, Tag.ofBoolean))
}
}
def not(e: ID): ID = {
def and(conjuncts: ID*): ID = and(Vec.unsafe(conjuncts.toArray))
def and(conjuncts: Vec[ID]): ID = rec(ComputationF(bool.And, conjuncts, Tag.ofBoolean))
def or(disjuncts: ID*): ID =
rec(ComputationF(bool.Or, Vec.unsafe(disjuncts.toArray), Tag.ofBoolean))
def not(e: ID): ID =
rec(ComputationF(bool.Not, Seq(e), Tag.ofBoolean))
}
def implies(cond: ID, eff: ID): ID = {
def implies(cond: ID, eff: ID): ID =
or(not(cond), eff)
}
}
class LazyTreeSpec[@specialized(Int) K](f: K => ExprF[K], g: Context => ExprF[IR[ID]] => IR[ID]) {
......@@ -421,7 +388,6 @@ object SatisfactionProblem {
coalgebra: FCoalgebra[ExprF, X],
optimize: Boolean = true): RootedLazyTree[X, Total, cats.Id] = {
val lt = new LazyTreeSpec[X](coalgebra, compiler)
val x = lt.get(root)
val totalTrees = new ILazyTree[X, Total, cats.Id] {
override def getExt(k: X): Total[ID] = lt.get(lt.get(k).value)
......
......@@ -70,11 +70,20 @@ case class ProblemContext(intTag: BoxedInt[Literal],
def encode(v: common.Term)(implicit argRewrite: Arg => Tentative[Literal]): Tentative[Literal] =
v match {
case IntLiteral(i) => IntLit(i).asConstant(intTag)
case lv @ LocalVar(_, tpe) if tpe.isSubtypeOf(Type.Integer) =>
assert(tpe == Type.Integer)
case lv @ LocalVar(id, tpe) if tpe.isSubtypeOf(Type.Time) =>
assert(tpe == Type.Time)
val e: Tentative[Int] = id match {
case Type.Start => temporalOrigin
case Type.End => temporalHorizon
case _ =>
Input[Int](Ident(lv)).subjectTo(tp => temporalOrigin <= tp && tp <= temporalHorizon)
}
intBox(intTag, e)
case lv @ LocalVar(_, tpe) if tpe.isSubtypeOf(Type.Integers) =>
assert(tpe == Type.Integers)
Input[Literal](Ident(lv))(intTag)
case lv @ LocalVar(_, tpe) =>
assert(!tpe.isSubtypeOf(Type.Numeric))
assert(!tpe.isSubtypeOf(Type.Reals))
Input[Literal](Ident(lv))(specializedTags(tpe))
case i @ Instance(_, tpe) => ObjLit(i).asConstant(specializedTags(tpe))
case a: Arg => argRewrite(a)
......@@ -88,7 +97,7 @@ case class ProblemContext(intTag: BoxedInt[Literal],
}
def encodeAsInt(e: common.Expr)(
implicit argRewrite: Arg => Tentative[Literal]): Tentative[Int] = {
assert(e.typ.isSubtypeOf(Type.Integer))
assert(e.typ.isSubtypeOf(Type.Integers))
intUnbox(encode(e))
}
def applyOperator(op: BinaryOperator,
......@@ -250,9 +259,10 @@ case class ProblemContext(intTag: BoxedInt[Literal],
}
object ProblemContext {
import Type._
def extract(m: Seq[InModuleBlock]): ProblemContext = {
val objectTypes = m.collect { case TypeDeclaration(t) if t != Type.Integer => t }
val objectSubtypes = mutable.LinkedHashMap[Type, mutable.Set[Type]]()
val objectTypes = m.collect { case TypeDeclaration(t: ObjType) => t }
val objectSubtypes = mutable.LinkedHashMap[ObjType, mutable.Set[ObjType]]()
val instances = m
.collect { case InstanceDeclaration(i) => i }
.map { case i @ Instance(_, t) => (t, i) }
......@@ -268,15 +278,14 @@ object ProblemContext {
case None =>
}
}
val x = mutable.ArrayBuffer[Type]()
val roots = objectTypes.collect { case t @ Type(_, None) => t }.toList
val x = mutable.ArrayBuffer[ObjType]()
def process(t: Type): Unit = {
def process(t: ObjType): Unit = {
assert(!x.contains(t))
x += t
objectSubtypes(t).foreach(process)
}
roots.foreach(process)
process(ObjectTop)
val tmp: List[(ObjLit, Int)] =
x.toList
......@@ -286,10 +295,8 @@ object ProblemContext {
val toIndex = tmp.toMap
assert(toIndex.size == fromIndex.size)
def tagOf(t: Type): TagIsoInt[ObjLit] = {
assert(t != Type.Integer)
def instancesOf(t: Type): Seq[ObjLit] =
def tagOf(t: ObjType): TagIsoInt[ObjLit] = {
def instancesOf(t: ObjType): Seq[ObjLit] =
instances.getOrElse(t, Seq()) ++ objectSubtypes(t).flatMap(instancesOf)
def continuousMinMax(is: Seq[ObjLit]): (Int, Int) = {
val sorted = is.sortBy(t => toIndex(t)).toList
......@@ -359,11 +366,10 @@ object ProblemContext {
}
val memo = mutable.Map[Type, TagIsoInt[ObjLit]]()
val specializedTag = (t: Type) =>
if(t == Type.Integer)
intTag
else
memo.getOrElseUpdate(t, tagOf(t))
val specializedTag: Type => TagIsoInt[_] = {
case _: IIntType => intTag
case t: ObjType => memo.getOrElseUpdate(t, tagOf(t))
}
ProblemContext(intTag.asInstanceOf[BoxedInt[Literal]],
topTag.asInstanceOf[TagIsoInt[Literal]],
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment