Designing Type-Safe Query DSLs in Scala

Build compile-time safe database queries with zero runtime string errors. Learn how to create fluent query APIs that catch typos, type mismatches, and schema changes at compile time using Scala's type system.

Introduction: Beyond Query Strings

MongoDB queries in most applications look like this:

collection.find(
  BsonDocument("""{
    "userId": "12345",
    "status": { "$in": ["active", "pending"] },
    "createdAt": { "$gte": { "$date": "2024-01-01T00:00:00Z" } }
  }""")
)

The problems are obvious:

  • Typos in field names go undetected until runtime
  • Type mismatches (string where int expected) cause cryptic errors
  • Refactoring is dangerous - rename a field and miss a query string
  • No IDE support - autocomplete and jump-to-definition don’t work on field names inside strings

This article presents a type-safe query DSL that catches these errors at compile time while maintaining the expressiveness of MongoDB queries. After millions of queries in production, this approach has eliminated an entire class of bugs.

The Case for Type-Safe Queries

Consider this scenario:

// V1: Product has simple pricing
case class Product(
  id: String,
  name: String,
  price: BigDecimal
)

// Query for expensive products
collection.find(BsonDocument("""{ "price": { "$gt": 100 } }"""))

Six months later, pricing becomes complex:

// V2: Product has tiered pricing
case class Product(
  id: String,
  name: String,
  pricing: PricingTiers
)

case class PricingTiers(
  basePrice: BigDecimal,
  volumeDiscounts: List[Discount]
)

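Now every "price" query silently stops matching documents written in the new shape: MongoDB does not reject a filter on a missing field, it simply returns fewer (or no) results. With a type-safe DSL, removing the field definition together with the case class member turns every stale query into a compile error:

// Compile error once PriceField is removed from the companion object
Product.PriceField  // Does not compile after refactoring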

Designing Field References with Types

The foundation is a type-parameterized field reference:

case class Field[Model, FieldType](
  name: String,
  converter: StrConverter[FieldType]
) {
  // Generate field path for nested queries
  def path: String = name

  // Type-safe value conversion
  def toBson(value: FieldType): BsonValue =
    converter.toBson(value)
}

object Field {
  // Convenience constructor with implicit converter
  def apply[M, T](name: String)(implicit conv: StrConverter[T]): Field[M, T] =
    new Field[M, T](name, conv)
}
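To make the role of the two type parameters concrete, here is a minimal sketch that drops the converter and keeps only the types. The `equalTo` helper and the reduced `Field` are assumptions made for illustration; they are not part of the DSL in this article.

```scala
object FieldTypeSketch {
  // Stand-in model; only the field types matter here.
  final case class Product(name: String, price: BigDecimal)

  // Reduced field reference: a name plus the model and value types.
  final case class Field[Model, FieldType](name: String)

  val NameField: Field[Product, String]      = Field("name")
  val PriceField: Field[Product, BigDecimal] = Field("price")

  // Hypothetical helper: pairs a field name with a value of the field's own type.
  def equalTo[M, T](field: Field[M, T], value: T): (String, T) =
    field.name -> value

  val ok: (String, BigDecimal) = equalTo(PriceField, BigDecimal(100))

  // Neither of these compiles, which is the point of the exercise:
  // equalTo(PriceField, "cheap")       // String is not BigDecimal
  // equalTo(NameField, BigDecimal(1))  // BigDecimal is not String
}
```

Because `Field` is invariant in `FieldType`, the value argument has to match the declared type exactly, so a mismatch is rejected before any query is built.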

Type Converters

The StrConverter type class handles conversion between Scala types and BSON:

trait StrConverter[T] {
  def fromString(str: String): T
  def toString(value: T): String
  def toBson(value: T): BsonValue
  def fromBson(bson: BsonValue): T
}

object StrConverter {
  implicit val stringConverter: StrConverter[String] = new StrConverter[String] {
    override def fromString(str: String): String = str
    override def toString(value: String): String = value
    override def toBson(value: String): BsonValue = BsonString(value)
    override def fromBson(bson: BsonValue): String = bson.asString().getValue
  }

  implicit val longConverter: StrConverter[Long] = new StrConverter[Long] {
    override def fromString(str: String): Long = str.toLong
    override def toString(value: Long): String = value.toString
    override def toBson(value: Long): BsonValue = BsonInt64(value)
    override def fromBson(bson: BsonValue): Long = bson.asInt64().getValue
  }

  implicit val dateTimeConverter: StrConverter[DateTime] = new StrConverter[DateTime] {
    override def fromString(str: String): DateTime = DateTime.parse(str)
    override def toString(value: DateTime): String = value.toString()
    override def toBson(value: DateTime): BsonValue = BsonDateTime(value.getMillis)
    override def fromBson(bson: BsonValue): DateTime = new DateTime(bson.asDateTime().getValue)
  }

  implicit def optionConverter[T](implicit inner: StrConverter[T]): StrConverter[Option[T]] =
    new StrConverter[Option[T]] {
      // The empty string encodes None, mirroring toString below
      override def fromString(str: String): Option[T] =
        if (str.isEmpty) None else Some(inner.fromString(str))
      override def toString(value: Option[T]): String =
        value.map(v => inner.toString(v)).getOrElse("")
      override def toBson(value: Option[T]): BsonValue =
        value.map(inner.toBson).getOrElse(BsonNull())
      override def fromBson(bson: BsonValue): Option[T] =
        if (bson.isNull) None else Some(inner.fromBson(bson))
    }
}
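The way the Option instance is derived from an inner instance can be tried out without any BSON dependency. The sketch below is a reduced stand-in (`StrCodec`, `encode`, `decode` and `roundTrip` are hypothetical names) that keeps only the String half of the type class and assumes the empty string stands for None.

```scala
object ConverterSketch {
  // Reduced version of the type class: String conversion only, no BSON.
  trait StrCodec[T] {
    def encode(value: T): String
    def decode(str: String): T
  }

  object StrCodec {
    implicit val longCodec: StrCodec[Long] = new StrCodec[Long] {
      def encode(value: Long): String = value.toString
      def decode(str: String): Long   = str.toLong
    }

    // Derived instance: the compiler assembles StrCodec[Option[T]] from StrCodec[T].
    // Assumed convention: the empty string stands for None.
    implicit def optionCodec[T](implicit inner: StrCodec[T]): StrCodec[Option[T]] =
      new StrCodec[Option[T]] {
        def encode(value: Option[T]): String =
          value.map(v => inner.encode(v)).getOrElse("")
        def decode(str: String): Option[T] =
          if (str.isEmpty) None else Some(inner.decode(str))
      }
  }

  // Round-trip helper: decode(encode(x)) should give x back.
  def roundTrip[T](value: T)(implicit codec: StrCodec[T]): T =
    codec.decode(codec.encode(value))
}
```

One limitation of the empty-string convention is that it cannot distinguish None from Some("") for string-valued fields, so a real implementation may prefer a different encoding.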

Model Field Definitions

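Define fields once in companion objects. The BigDecimal and List[String] fields below assume matching StrConverter instances, written in the same way as the String, Long and DateTime ones above: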

case class Product(
  id: String,
  name: String,
  price: BigDecimal,
  category: String,
  createdAt: DateTime,
  tagsOpt: Option[List[String]]
)

object Product {
  val IdField = Field[Product, String]("_id")
  val NameField = Field[Product, String]("name")
  val PriceField = Field[Product, BigDecimal]("price")
  val CategoryField = Field[Product, String]("category")
  val CreatedAtField = Field[Product, DateTime]("created_at")
  val TagsOptField = Field[Product, Option[List[String]]]("tags")

  // Nested field access via dot notation (applies to the V2 model, where
  // pricing is a sub-document)
  val PricingBaseField = Field[Product, BigDecimal]("pricing.basePrice")
}
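Nested paths written as string literals reintroduce some of the typo risk the DSL is meant to remove. One possible refinement, sketched here under the assumption of a reduced `Field` without converters, is a hypothetical `/` combinator that joins a parent field with a child field and checks that the types line up.

```scala
object NestedPathSketch {
  final case class Field[Model, T](path: String) {
    // Hypothetical combinator: descend from this field into a field of its value type.
    def /[Child](child: Field[T, Child]): Field[Model, Child] =
      Field[Model, Child](s"$path.${child.path}")
  }

  final case class PricingTiers(basePrice: BigDecimal)
  final case class Product(pricing: PricingTiers)

  val Pricing: Field[Product, PricingTiers]      = Field("pricing")
  val BasePrice: Field[PricingTiers, BigDecimal] = Field("basePrice")

  // Only compiles because BasePrice is a field of PricingTiers.
  val PricingBase: Field[Product, BigDecimal] = Pricing / BasePrice
}
```

With this shape, renaming the sub-document or one of its members only requires changing a single field definition.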

Building Operator Abstractions

MongoDB has many query operators. We model them as a sealed trait hierarchy:

sealed trait Operator

// Binary operators: field OP value
sealed trait BinaryOperator extends Operator {
  def mongoOp: String
  def toBson[T](field: Field[_, T], value: T): BsonDocument
}

case object Equal extends BinaryOperator {
  override def mongoOp: String = "$eq"
  override def toBson[T](field: Field[_, T], value: T): BsonDocument =
    BsonDocument(field.path -> field.toBson(value))
}

case object GreaterThan extends BinaryOperator {
  override def mongoOp: String = "$gt"
  override def toBson[T](field: Field[_, T], value: T): BsonDocument =
    BsonDocument(field.path -> BsonDocument(mongoOp -> field.toBson(value)))
}

case object LessThan extends BinaryOperator {
  override def mongoOp: String = "$lt"
  override def toBson[T](field: Field[_, T], value: T): BsonDocument =
    BsonDocument(field.path -> BsonDocument(mongoOp -> field.toBson(value)))
}

case object GreaterThanOrEqual extends BinaryOperator {
  override def mongoOp: String = "$gte"
  override def toBson[T](field: Field[_, T], value: T): BsonDocument =
    BsonDocument(field.path -> BsonDocument(mongoOp -> field.toBson(value)))
}

// Unary operators: field OP
sealed trait UnaryOperator extends Operator {
  def mongoOp: String
  def toBson[T](field: Field[_, T]): BsonDocument
}

case object Exists extends UnaryOperator {
  override def mongoOp: String = "$exists"
  override def toBson[T](field: Field[_, T]): BsonDocument =
    BsonDocument(field.path -> BsonDocument(mongoOp -> BsonBoolean(true)))
}

// Note: { field: null } matches explicit nulls as well as documents where the field is absent
case object IsNull extends UnaryOperator {
  override def mongoOp: String = "$eq"
  override def toBson[T](field: Field[_, T]): BsonDocument =
    BsonDocument(field.path -> BsonNull())
}

// Many operators: field IN [value1, value2, ...]
sealed trait ManyOperator extends Operator {
  def mongoOp: String
  def toBson[T](field: Field[_, T], values: Seq[T]): BsonDocument
}

case object In extends ManyOperator {
  override def mongoOp: String = "$in"
  override def toBson[T](field: Field[_, T], values: Seq[T]): BsonDocument =
    BsonDocument(
      field.path -> BsonDocument(
        mongoOp -> BsonArray.fromIterable(values.map(field.toBson))
      )
    )
}

case object NotIn extends ManyOperator {
  override def mongoOp: String = "$nin"
  override def toBson[T](field: Field[_, T], values: Seq[T]): BsonDocument =
    BsonDocument(
      field.path -> BsonDocument(
        mongoOp -> BsonArray.fromIterable(values.map(field.toBson))
      )
    )
}
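To see the document shapes these operators are meant to produce, the following sketch renders them with a tiny stand-in value type instead of the driver's BSON classes. `Value`, `Doc`, `render` and the three helper functions are assumptions made for this illustration only.

```scala
object OperatorSketch {
  // Tiny stand-in for BSON values, only to make the output visible.
  sealed trait Value { def render: String }
  final case class Num(n: BigDecimal) extends Value { def render: String = n.toString }
  final case class Str(s: String) extends Value { def render: String = "\"" + s + "\"" }
  final case class Arr(items: Seq[Value]) extends Value {
    def render: String = items.map(_.render).mkString("[", ", ", "]")
  }
  final case class Doc(fields: (String, Value)*) extends Value {
    def render: String =
      fields.map { case (k, v) => "\"" + k + "\": " + v.render }.mkString("{", ", ", "}")
  }

  // Equality uses the short form { field: value }.
  def equal(path: String, v: Value): Doc = Doc(path -> v)

  // Comparison operators nest the operator under the field: { field: { $op: value } }.
  def compare(op: String, path: String, v: Value): Doc = Doc(path -> Doc(op -> v))

  // Set operators wrap the values in an array: { field: { $in: [ ... ] } }.
  def in(path: String, vs: Seq[Value]): Doc = Doc(path -> Doc("$in" -> Arr(vs)))
}
```

For example, `compare("$gt", "price", Num(BigDecimal(100))).render` yields `{"price": {"$gt": 100}}`, the same shape as the hand-written query string from the introduction.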

Creating Fluent Query APIs

Now we build the query DSL that chains operations:

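// The ClassTag bound (scala.reflect.ClassTag) is needed because the driver's
// find call requires one for the result type
class ReadQuery[Model: ClassTag](
  collection: MongoCollection[Model],
  filters: Seq[BsonDocument] = Seq.empty,
  sortOpt: Option[BsonDocument] = None,
  limitOpt: Option[Int] = None,
  skipOpt: Option[Int] = None
) {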

  // Binary operators
  def where[T](field: Field[Model, T], operator: BinaryOperator, value: T): ReadQuery[Model] =
    new ReadQuery(
      collection,
      filters :+ operator.toBson(field, value),
      sortOpt,
      limitOpt,
      skipOpt
    )

  // Unary operators
  def where[T](field: Field[Model, T], operator: UnaryOperator): ReadQuery[Model] =
    new ReadQuery(
      collection,
      filters :+ operator.toBson(field),
      sortOpt,
      limitOpt,
      skipOpt
    )

  // Many operators
  def where[T](
    field: Field[Model, T],
    operator: ManyOperator,
    value1: T,
    remaining: T*
  ): ReadQuery[Model] =
    new ReadQuery(
      collection,
      filters :+ operator.toBson(field, value1 +: remaining),
      sortOpt,
      limitOpt,
      skipOpt
    )

  // Sorting
  def orderBy[T](field: Field[Model, T], ascending: Boolean = true): ReadQuery[Model] =
    new ReadQuery(
      collection,
      filters,
      Some(BsonDocument(field.path -> BsonInt32(if (ascending) 1 else -1))),
      limitOpt,
      skipOpt
    )

  // Pagination
  def limit(n: Int): ReadQuery[Model] =
    new ReadQuery(collection, filters, sortOpt, Some(n), skipOpt)

  def skip(n: Int): ReadQuery[Model] =
    new ReadQuery(collection, filters, sortOpt, limitOpt, Some(n))

  // Execution
  def fetchOne: Future[Option[Model]] =
    // headOption() yields None for an empty result and needs no ExecutionContext
    buildFindIterable.limit(1).headOption()

  def fetchList: Future[Seq[Model]] = {
    val query = buildFindIterable
    query.toFuture()
  }

  def fetchAll: Future[Seq[Model]] = fetchList

  def count: Future[Long] = {
    val filter = combineFilters
    collection.countDocuments(filter).toFuture()
  }

  // Build the driver's FindObservable (the Scala driver's counterpart to the Java FindIterable)
  private def buildFindIterable: FindObservable[Model] = {
    val filter = combineFilters
    var query = collection.find(filter)

    sortOpt.foreach(sort => query = query.sort(sort))
    limitOpt.foreach(limit => query = query.limit(limit))
    skipOpt.foreach(skip => query = query.skip(skip))

    query
  }

  private def combineFilters: BsonDocument = {
    if (filters.isEmpty) {
      BsonDocument()
    } else if (filters.size == 1) {
      filters.head
    } else {
      BsonDocument("$and" -> BsonArray.fromIterable(filters))
    }
  }
}
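The two ideas that carry this class, immutable chaining and the three-way filter combination, can be shown in isolation. The sketch below is a simplified stand-in (`Query` and `combined` are hypothetical names) that stores filters as plain strings rather than BSON documents.

```scala
object BuilderSketch {
  // Filters are plain strings here; the builder in the article holds BSON documents.
  final case class Query(
    filters: Vector[String] = Vector.empty,
    limitOpt: Option[Int] = None
  ) {
    // Every step returns a new Query; the receiver is never modified.
    def where(filter: String): Query = copy(filters = filters :+ filter)
    def limit(n: Int): Query         = copy(limitOpt = Some(n))

    // Same three cases as combineFilters: none, exactly one, or an explicit $and.
    def combined: String = filters match {
      case Vector()       => "{}"
      case Vector(single) => single
      case many           => many.mkString("{\"$and\": [", ", ", "]}")
    }
  }
}
```

Because each call returns a fresh value, a partially built query can be shared and extended in different directions without one caller affecting another. Using a case class with `copy` also avoids repeating all constructor arguments in every method.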

Usage Examples

Now queries become type-safe and readable:

// Simple equality
Product.repository
  .where(Product.CategoryField, Equal, "electronics")
  .fetchList

// Range queries
Product.repository
  .where(Product.PriceField, GreaterThan, BigDecimal(100))
  .where(Product.PriceField, LessThan, BigDecimal(1000))
  .orderBy(Product.PriceField, ascending = true)
  .fetchList

// IN queries
Product.repository
  .where(Product.CategoryField, In, "electronics", "computers", "gadgets")
  .fetchList

// NULL checks
Product.repository
  .where(Product.TagsOptField, IsNull)
  .fetchList

// Pagination
Product.repository
  .where(Product.CategoryField, Equal, "electronics")
  .orderBy(Product.CreatedAtField, ascending = false)
  .limit(20)
  .skip(40)  // Page 3
  .fetchList

// Complex queries
Product.repository
  .where(Product.CategoryField, Equal, "electronics")
  .where(Product.PriceField, GreaterThanOrEqual, BigDecimal(50))
  .where(Product.TagsOptField, Exists)
  .orderBy(Product.CreatedAtField, ascending = false)
  .limit(100)
  .fetchList

Handling Aggregations and Joins

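Aggregation pipelines are ordered lists of stages rather than a single filter, so the builder accumulates stage documents in the order they are added: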

class AggregateQueryBuilder[Model](
  collection: MongoCollection[Model],
  stages: Seq[BsonDocument] = Seq.empty
) {

  def matchStage(filters: BsonDocument): AggregateQueryBuilder[Model] =
    new AggregateQueryBuilder(
      collection,
      stages :+ BsonDocument("$match" -> filters)
    )

  def groupBy[T](
    field: Field[Model, T],
    accumulators: (String, BsonDocument)*
  ): AggregateQueryBuilder[Model] = {
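    // Prepend the grouping key to the accumulators and build one document
    val groupFields: Seq[(String, BsonValue)] =
      ("_id" -> BsonString(s"$$${field.path}")) +: accumulators
    val groupDoc = BsonDocument(groupFields)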

    new AggregateQueryBuilder(
      collection,
      stages :+ BsonDocument("$group" -> groupDoc)
    )
  }

  def addFields(fields: (String, BsonValue)*): AggregateQueryBuilder[Model] =
    new AggregateQueryBuilder(
      collection,
      stages :+ BsonDocument("$addFields" -> BsonDocument(fields))
    )

  def sort[T](field: Field[Model, T], ascending: Boolean = true): AggregateQueryBuilder[Model] =
    new AggregateQueryBuilder(
      collection,
      stages :+ BsonDocument(
        "$sort" -> BsonDocument(field.path -> BsonInt32(if (ascending) 1 else -1))
      )
    )

  def limit(n: Int): AggregateQueryBuilder[Model] =
    new AggregateQueryBuilder(
      collection,
      stages :+ BsonDocument("$limit" -> BsonInt32(n))
    )

  def skip(n: Int): AggregateQueryBuilder[Model] =
    new AggregateQueryBuilder(
      collection,
      stages :+ BsonDocument("$skip" -> BsonInt32(n))
    )

  // Lookup (join)
  def lookup(
    from: String,
    localField: String,
    foreignField: String,
    as: String
  ): AggregateQueryBuilder[Model] =
    new AggregateQueryBuilder(
      collection,
      stages :+ BsonDocument(
        "$lookup" -> BsonDocument(
          "from" -> BsonString(from),
          "localField" -> BsonString(localField),
          "foreignField" -> BsonString(foreignField),
          "as" -> BsonString(as)
        )
      )
    )

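  // The driver's aggregate call needs a ClassTag for the result type, in
  // addition to whatever decoder the project supplies for R
  def execute[R](implicit decoder: Decoder[R], ct: ClassTag[R]): Future[Seq[R]] =
    collection.aggregate[R](stages).toFuture()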
}

Aggregation Example

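// Count products by category and keep 10 groups.
// Caveat: after $group the category value lives in _id, so the sort on
// "category" below does not rank the groups; a true "top 10" needs a sort
// on the computed "count" field.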
Product.repository
  .aggregate
  .groupBy(
    Product.CategoryField,
    "count" -> BsonDocument("$sum" -> BsonInt32(1)),
    "avgPrice" -> BsonDocument("$avg" -> BsonString("$price"))
  )
  .sort(Product.CategoryField, ascending = true)
  .limit(10)
  .execute[CategoryStats]
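For comparison, the result a "top categories" pipeline is expected to produce can be written against an in-memory collection. This is only a sketch of the semantics (group, sort on the computed count, then limit) using plain Scala collections rather than the driver; `topCategories` and the reduced `Product` and `CategoryStats` classes are stand-ins.

```scala
object AggregationSketch {
  final case class Product(category: String, price: BigDecimal)
  final case class CategoryStats(category: String, count: Int, avgPrice: BigDecimal)

  // In-memory equivalent of: $group by category, $sort by count descending, $limit n.
  def topCategories(products: Seq[Product], n: Int): Seq[CategoryStats] =
    products
      .groupBy(_.category)
      .map { case (category, group) =>
        CategoryStats(category, group.size, group.map(_.price).sum / group.size)
      }
      .toSeq
      .sortBy(s => (-s.count, s.category)) // ties broken by name to keep the order stable
      .take(n)
}
```

A function like this can double as a reference implementation in tests: run the pipeline against a small fixture collection and compare its output with the in-memory result.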

Element Match for Nested Arrays

Querying nested arrays requires special handling:

case class Order(
  id: String,
  customerId: String,
  items: List[OrderItem],
  createdAt: DateTime
)

case class OrderItem(
  productId: String,
  quantity: Int,
  price: BigDecimal
)

object Order {
  val ItemsField = Field[Order, List[OrderItem]]("items")

  // Element match query builder
  def itemsWithProductId(productId: String): BsonDocument =
    BsonDocument(
      "items" -> BsonDocument(
        "$elemMatch" -> BsonDocument(
          "productId" -> BsonString(productId)
        )
      )
    )

  def itemsWithQuantityGreaterThan(quantity: Int): BsonDocument =
    BsonDocument(
      "items" -> BsonDocument(
        "$elemMatch" -> BsonDocument(
          "quantity" -> BsonDocument("$gt" -> BsonInt32(quantity))
        )
      )
    )
}

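// Usage: ElemMatch and a where overload that accepts a pre-built BsonDocument
// are assumed here; neither is defined in the ReadQuery shown earlier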
Order.repository
  .where(Order.ItemsField, ElemMatch, Order.itemsWithProductId("prod-123"))
  .fetchList
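The reason nested arrays need `$elemMatch` at all is easiest to see in memory. With `$elemMatch`, a single array element has to satisfy every condition; with separate dotted conditions such as "items.productId" and "items.quantity", each condition may be satisfied by a different element. The sketch below models both behaviours with plain collections; the function names are hypothetical.

```scala
object ElemMatchSketch {
  final case class OrderItem(productId: String, quantity: Int)
  final case class Order(id: String, items: List[OrderItem])

  // $elemMatch semantics: one element must satisfy all conditions at once.
  def elemMatch(order: Order, productId: String, minQty: Int): Boolean =
    order.items.exists(i => i.productId == productId && i.quantity > minQty)

  // Separate dotted conditions: each condition can be met by a different element.
  def dotted(order: Order, productId: String, minQty: Int): Boolean =
    order.items.exists(_.productId == productId) &&
      order.items.exists(_.quantity > minQty)
}
```

An order containing one unit of product A and ten units of product B matches the dotted form for "A with quantity above 5" but not the `$elemMatch` form, which is usually the behaviour a caller intends.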

Pagination Strategies

Different approaches for different use cases:

Skip/Limit (Simple but Slow for Large Offsets)

sealed trait PaginationStrategy
case class Skip(offset: Int) extends PaginationStrategy
case class PageNumber(page: Int, pageSize: Int) extends PaginationStrategy {
  def toSkip: Skip = Skip((page - 1) * pageSize)
}

def fetchPage(strategy: PaginationStrategy, pageSize: Int): Future[Seq[Product]] =
  strategy match {
    case Skip(offset) =>
      Product.repository
        .orderBy(Product.CreatedAtField, ascending = false)
        .limit(pageSize)
        .skip(offset)
        .fetchList

    case PageNumber(page, size) =>
      fetchPage(PageNumber(page, size).toSkip, size)
  }
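The page-to-offset arithmetic is simple, but it is also where off-by-one errors and negative skips tend to creep in. A minimal sketch, assuming 1-based page numbers as in `toSkip` above (`offsetFor` is a hypothetical helper):

```scala
object OffsetSketch {
  // Converts a 1-based page number to a skip offset and rejects invalid input
  // instead of producing a negative skip.
  def offsetFor(page: Int, pageSize: Int): Int = {
    require(page >= 1, "page numbers start at 1")
    require(pageSize > 0, "pageSize must be positive")
    (page - 1) * pageSize
  }
}
```

With a page size of 20, page 1 starts at offset 0 and page 3 at offset 40. The server still has to walk past every skipped document, which is why this strategy slows down as the offset grows.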

Cursor-Based (Efficient for Large Datasets)

def fetchAfterCursor(
  lastSeenId: Option[String],
  pageSize: Int
): Future[Seq[Product]] = {
  val query = Product.repository

  val withCursor = lastSeenId match {
    case Some(id) =>
      query.where(Product.IdField, GreaterThan, id)
    case None =>
      query
  }

  withCursor
    .orderBy(Product.IdField, ascending = true)
    .limit(pageSize)
    .fetchList
}
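How a caller walks through all pages with this approach can be simulated in memory. The sketch below stands in for the database with a plain sequence of ids; `pageAfter` mirrors the "id greater than cursor, ordered by id, limited" query, and `allPages` is a hypothetical driver loop that feeds the last id of each page into the next request.

```scala
object CursorSketch {
  // In-memory stand-in for: where(_id > lastSeenId).orderBy(_id).limit(pageSize)
  def pageAfter(ids: Seq[String], lastSeenId: Option[String], pageSize: Int): Seq[String] =
    ids.sorted.filter(id => lastSeenId.forall(id > _)).take(pageSize)

  // Collects every page by using the last id of one page as the next cursor.
  def allPages(ids: Seq[String], pageSize: Int): List[Seq[String]] = {
    @scala.annotation.tailrec
    def loop(cursor: Option[String], acc: List[Seq[String]]): List[Seq[String]] = {
      val page = pageAfter(ids, cursor, pageSize)
      if (page.isEmpty) acc.reverse
      else loop(page.lastOption, page :: acc)
    }
    loop(None, Nil)
  }
}
```

Each request depends only on the last id seen, so the cost of a page stays roughly constant regardless of how deep into the result set it is, provided the sort field is indexed and unique.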

Real-World Usage Patterns

Repository Pattern

class ProductRepository(collection: MongoCollection[Product]) {
  def findById(id: String): Future[Option[Product]] =
    new ReadQuery(collection)
      .where(Product.IdField, Equal, id)
      .fetchOne

  def findByCategory(category: String, page: Int, pageSize: Int): Future[Seq[Product]] =
    new ReadQuery(collection)
      .where(Product.CategoryField, Equal, category)
      .orderBy(Product.CreatedAtField, ascending = false)
      .limit(pageSize)
      .skip((page - 1) * pageSize)
      .fetchList

  def findExpensiveProducts(minPrice: BigDecimal): Future[Seq[Product]] =
    new ReadQuery(collection)
      .where(Product.PriceField, GreaterThanOrEqual, minPrice)
      .orderBy(Product.PriceField, ascending = false)
      .fetchList

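  // Note: query is sent as a raw regular expression; escape it first if it
  // comes from user input
  def searchByName(query: String): Future[Seq[Product]] =
    collection
      .find(BsonDocument(Product.NameField.path -> BsonDocument("$regex" -> BsonString(query), "$options" -> BsonString("i"))))
      .toFuture()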

  def countByCategory(category: String): Future[Long] =
    new ReadQuery(collection)
      .where(Product.CategoryField, Equal, category)
      .count
}

Error Handling

def findProductSafely(id: String): Future[Either[String, Product]] =
  repository
    .findById(id)
    .map {
      case Some(product) => Right(product)
      case None => Left(s"Product not found: $id")
    }
    .recover {
      case ex: MongoException => Left(s"Database error: ${ex.getMessage}")
      case ex => Left(s"Unexpected error: ${ex.getMessage}")
    }
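A variation on the same idea replaces the string errors with a small error type, so callers can pattern match on the failure instead of parsing messages. This is a sketch using only the standard library; `LookupError`, `NotFound`, `Failed` and `findSafely` are hypothetical names, and the `lookup` parameter stands in for `repository.findById`.

```scala
import scala.concurrent.{ExecutionContext, Future}
import scala.util.control.NonFatal

object ErrorHandlingSketch {
  // Hypothetical error type; the version above uses plain strings.
  sealed trait LookupError
  final case class NotFound(id: String) extends LookupError
  final case class Failed(message: String) extends LookupError

  // `lookup` stands in for repository.findById.
  def findSafely[A](id: String)(lookup: String => Future[Option[A]])(
      implicit ec: ExecutionContext
  ): Future[Either[LookupError, A]] =
    lookup(id)
      .map(_.toRight[LookupError](NotFound(id)))
      .recover {
        // NonFatal leaves fatal JVM errors (for example OutOfMemoryError) alone.
        case NonFatal(ex) => Left(Failed(Option(ex.getMessage).getOrElse(ex.toString)))
      }
}
```

Keeping "not found" and "lookup failed" as separate cases lets an HTTP layer, for instance, map them to different status codes without inspecting message text.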

When to Use This Pattern

✅ Use When:

  • Building long-lived systems with evolving schemas
  • Multiple developers need to query the same collections
  • Refactoring is frequent
  • Type safety is more important than query flexibility

❌ Avoid When:

  • Queries are one-off or exploratory
  • Performance requires hand-tuned native queries
  • Team is uncomfortable with Scala's type system
  • Using a different database (this is MongoDB-specific)

Conclusion

Type-safe query DSLs transform database access from error-prone string manipulation to compiler-verified operations. The benefits compound over time:

  1. Refactoring confidence - Rename fields safely across the codebase
  2. Type safety - Compiler catches field name typos and type mismatches
  3. IDE support - Autocomplete, jump-to-definition, find-usages all work
  4. Self-documenting - Field definitions serve as schema documentation
  5. Gradual adoption - Can coexist with string-based queries

The implementation requires upfront investment in DSL design, but pays dividends through reduced bugs, faster development, and easier maintenance. After years in production, this approach has proven its worth across codebases of all sizes.
