Commit 163ff060 authored by Arthur Bit-Monnot's avatar Arthur Bit-Monnot

[pddl] Prepare addition of action factory

parent 74a715b0
......@@ -6,23 +6,24 @@ import dahu.planning.model.{common, full}
import dahu.planning.model.full._
import dahu.utils.errors._
import dahu.planning.pddl.Utils._
import dahu.planning.pddl.Ctx._
import dahu.planning.pddl.Resolver._
import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.language.implicitConversions
abstract class Factory {}
abstract class Factory {
class ModelFactory {
var model = PddlPredef.baseModel
def context: Ctx
implicit val resolver: Resolver = new Resolver {
def scope: Scope = context.scope
implicit val ctx = new Ctx {
override def id(name: String): Id = Id(common.RootScope, name)
override def typeOf(name: String): Type =
model.findType(name).getOrElse(unexpected(s"unknown type: $name"))
context.findType(name).getOrElse(unexpected(s"unknown type: $name"))
override def variable(name: String): StaticExpr = model.findDeclaration(name) match {
override def variable(name: String): StaticExpr = context.findDeclaration(name) match {
case Some(v: VarDeclaration[_]) => CommonTerm(v.variable)
case _ => unexpected(s"unknown variable: $name")
}
......@@ -30,19 +31,34 @@ class ModelFactory {
override def nextId(): String = dahu.planning.model.reservedPrefix + next().toString
}
val translators = mutable.HashMap[String, FunctionCompat]()
def getTranslator(name: String): FunctionCompat = translators(name)
def recordFunction(pddlPred: NamedTypedList): Unit = {
def getTranslator(name: String): FunctionCompat
def id(name: String): Id = Id(resolver.scope, name)
protected def asFluent(name: String, args: Seq[String]): Fluent =
Fluent(getTranslator(name).model, args.map(resolver.variable))
def hasType(name: String): Boolean = context.findType(name).nonEmpty
}
class ModelFactory(predef: PddlPredef.type) extends Factory {
private var model = predef.baseModel
override def context: Model = model
private val translators = mutable.HashMap[String, FunctionCompat]()
override def getTranslator(name: String): FunctionCompat = translators(name)
def rec(block: full.InModuleBlock): Unit = model = (model + block).get
private def recordFunction(pddlPred: NamedTypedList): Unit = {
val t = FunctionCompat(pddlPred)
translators += ((t.name, t))
rec(FunctionDeclaration(t.model))
}
def rec(block: full.InModuleBlock): Unit = model = (model + block).get
def id(name: String): Id = Id(common.RootScope, name)
def recordType(tpe: ast.Tpe): Unit = {
private def recordType(tpe: ast.Tpe): Unit = {
val ast.Tpe(name, parent) = tpe
assert(!hasType(name), s"type already recorded: $name")
assert(parent.forall(hasType), s"parent not recorded: $parent")
......@@ -62,13 +78,11 @@ class ModelFactory {
}
}
def recordInstance(name: String, tpe: String): Unit = {
private def recordInstance(name: String, tpe: String): Unit = {
rec(InstanceDeclaration(Instance(id(name), typeOf(tpe))))
}
private def asFluent(name: String, args: Seq[String]): Fluent =
Fluent(getTranslator(name).model, args.map(ctx.variable))
def recordInitialState(e: Exp): Unit = {
private def recordInitialState(e: Exp): Unit = {
val assertion = e match {
case ast.AssertionOnFunction(funcName) =>
getTranslator(funcName).effect(e)
......@@ -76,7 +90,7 @@ class ModelFactory {
rec(TemporallyQualifiedAssertion(Equals(Interval(predef.Start, predef.Start)), assertion))
}
def recordGoal(e: Exp): Unit = e match {
private def recordGoal(e: Exp): Unit = e match {
case ast.And(goals) =>
goals.foreach(recordGoal)
case ast.AssertionOnFunction(name) =>
......@@ -88,8 +102,6 @@ class ModelFactory {
))
}
def hasType(name: String): Boolean = model.findType(name).nonEmpty
def loadDomain(dom: Domain): Unit = {
val types = dom.getTypes.asScala.map {
case ast.ReadTpe(tpe: ast.Tpe) => tpe
......@@ -126,5 +138,5 @@ class ModelFactory {
recordGoal(pb.getGoal)
}
def result: Model = model
def result: Model = context
}
......@@ -4,7 +4,7 @@ import dahu.planning.model.common
import dahu.planning.model.common._
import dahu.planning.model.full._
import dahu.planning.pddl.Utils._
import dahu.planning.pddl.Ctx._
import dahu.planning.pddl.Resolver._
import dahu.utils.errors._
import fr.uga.pddl4j.parser._
......@@ -20,7 +20,7 @@ abstract class FunctionCompat() {
}
object FunctionCompat {
def apply(pddl: NamedTypedList)(implicit ctx: Ctx): FunctionCompat = {
def apply(pddl: NamedTypedList)(implicit ctx: Resolver): FunctionCompat = {
pddl.getTypes.asScala match {
case Seq() => new DefaultPredicate(pddl)
case Seq(tpe) => new DefaultFunction(pddl)
......@@ -29,7 +29,7 @@ object FunctionCompat {
}
}
class DefaultPredicate(pddl: NamedTypedList)(implicit ctx: Ctx) extends FunctionCompat {
class DefaultPredicate(pddl: NamedTypedList)(implicit ctx: Resolver) extends FunctionCompat {
override val name: String = pddl.getName.getImage
private val tpe = pddl.getTypes.asScala match {
case Seq() => PddlPredef.Boolean
......@@ -63,7 +63,7 @@ class DefaultPredicate(pddl: NamedTypedList)(implicit ctx: Ctx) extends Function
}
}
class DefaultFunction(pddl: NamedTypedList)(implicit ctx: Ctx) extends FunctionCompat {
class DefaultFunction(pddl: NamedTypedList)(implicit ctx: Resolver) extends FunctionCompat {
override val name: String = pddl.getName.getImage
private val tpe = pddl.getTypes.asScala match {
......
......@@ -20,7 +20,7 @@ object Main extends App {
println(dom)
// println(pb)
val factory = new ModelFactory()
val factory = new ModelFactory(PddlPredef)
factory.loadDomain(dom)
factory.loadProblem(pb)
......
package dahu.planning.pddl
import dahu.planning.model.common.{Id, Type}
import dahu.planning.model.common.{Id, Scope, Type}
import dahu.planning.model.full.StaticExpr
trait Ctx {
trait Resolver {
def scope: Scope
def typeOf(name: String): Type
def id(name: String): Id
def variable(name: String): StaticExpr
......@@ -11,9 +12,9 @@ trait Ctx {
def nextId(): String
}
object Ctx {
def typeOf(name: String)(implicit ctx: Ctx): Type = ctx.typeOf(name)
object Resolver {
def typeOf(name: String)(implicit ctx: Resolver): Type = ctx.typeOf(name)
def id(name: String)(implicit ctx: Ctx): Id = ctx.id(name)
def id(name: String)(implicit ctx: Resolver): Id = ctx.id(name)
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment