Scala: Traversing a graph in a functional way
I recently played around with graphs and tried to implement Depth-First Search in a functional and recursive way.
In the following examples, we work with this simple directed, unweighted graph with vertices a to f:
Each vertex is represented by an object of the following class:
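// Edges are passed by name and stored in a lazy val, so that cyclic graphs
// (such as c -> e -> d -> c) can be built from immutable vertices.
class Vertex(val name: Char, outgoing: => Set[Vertex]) {
  lazy val edges: Set[Vertex] = outgoing
  override def toString: String = name.toString
}
Let's start with an approach that is neither functional nor recursive, but at least straightforward: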
import scala.collection.mutable

def dfsMutableIterative(start: Vertex): Set[Vertex] = {
  var current: Vertex = start
  // vertices visited so far
  val found: mutable.Set[Vertex] = mutable.Set[Vertex]()
  // vertices still to be looked at (last in, first out)
  val stack: mutable.Stack[Vertex] = mutable.Stack[Vertex]()
  stack.push(current)
  while (stack.nonEmpty) {
    current = stack.pop()
    // a vertex can be pushed more than once, so skip it if it was already found
    if (!found.contains(current)) {
      found += current
      for (next <- current.edges) {
        stack.push(next)
      }
    }
  }
  found.toSet
}
Here, we are working with two mutable data structures: a set that tracks the vertices we have found during traversal, plus a stack that holds the yet-to-visit vertices we encounter along the way.
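To try the iterative version, the following sketch builds the example graph and records the order in which vertices are found. The edge list a -> b, a -> c, c -> e, e -> f, e -> d, d -> c is inferred from the walkthrough further down (f is assumed to have no outgoing edges), and dfsMutableIterativeOrder is a hypothetical variant of dfsMutableIterative that returns the visit order instead of a set. Edges are passed by name so that the cycle can be wired up from immutable vertices.

```scala
import scala.collection.mutable

// Edges are passed by name and evaluated lazily, so that a cyclic graph
// can be built from immutable vertices (an assumption of this sketch).
class Vertex(val name: Char, outgoing: => Set[Vertex]) {
  lazy val edges: Set[Vertex] = outgoing
  override def toString: String = name.toString
}

object IterativeExample {
  // Example graph, edges inferred from the walkthrough:
  // a -> b, a -> c, c -> e, e -> f, e -> d, d -> c
  lazy val a: Vertex = new Vertex('a', Set(b, c))
  lazy val b: Vertex = new Vertex('b', Set.empty[Vertex])
  lazy val c: Vertex = new Vertex('c', Set(e))
  lazy val d: Vertex = new Vertex('d', Set(c))
  lazy val e: Vertex = new Vertex('e', Set(f, d))
  lazy val f: Vertex = new Vertex('f', Set.empty[Vertex])

  // Hypothetical variant of dfsMutableIterative that returns the vertices
  // in the order in which they were first found.
  def dfsMutableIterativeOrder(start: Vertex): List[Vertex] = {
    val found = mutable.Set[Vertex]()
    val order = mutable.ListBuffer[Vertex]()
    val stack = mutable.Stack[Vertex]()
    stack.push(start)
    while (stack.nonEmpty) {
      val current = stack.pop()
      if (!found.contains(current)) {
        found += current
        order += current
        // neighbours pushed last are popped first
        current.edges.foreach(next => stack.push(next))
      }
    }
    order.toList
  }
}
```

Because the stack is last in, first out, the neighbour pushed last is visited first: with the edge order above, the iterative version explores c's branch before b and finds a, c, e, d, f, b, whereas a recursive traversal visits b right after a.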
Another approach is recursive, but still uses mutable state:
import scala.collection.mutable

def dfsMutableRecursive(start: Vertex): Set[Vertex] = {
  val found: mutable.Set[Vertex] = mutable.Set[Vertex]()
  // mark current as found, then recurse into every neighbour not found yet
  def recurse(current: Vertex): Unit = {
    found += current
    for (next <- current.edges) {
      if (!found.contains(next)) {
        recurse(next)
      }
    }
  }
  recurse(start)
  found.toSet
}
How can we solve this in a completely functional way? We need to keep track of the vertices we have already encountered, but at the same time, we don't want to update any mutable data structures.
One way to achieve this is to combine a for comprehension with recursion, where each recursive call receives a list of the vertices on the path leading to it (the accumulator), which grows by one vertex per recursion level:
def dfsFunctional(current: Vertex, acc: List[Vertex]): Set[Vertex] = {
  // recurse into every neighbour that is not already on the path (acc),
  // then merge the resulting sets and add the current vertex
  (for (next <- current.edges if !acc.contains(next))
    yield dfsFunctional(next, current +: acc)).flatten + current
}
Now, I don't know about you, but I always have a hard time wrapping my head around what exactly happens during execution when recursion is involved. To really grasp how the above code works, I have drawn the following chart:
Well, at first glance this might make things even worse, but bear with me: if we follow the arrows, everything works out fine.
When we call the function with vertex object a and an empty list as its parameters (i.e. dfsFunctional(a, Nil), given that a is the val holding the object for vertex 'a'), the following happens:
We start at recursion level 0, that is, we have just called our function. For the parameter current we passed vertex a, and our acc (for accumulator) is a list with no entries.
We now enter the for loop and iterate over the set of edges of current (which we retrieve by calling current.edges). The loop skips vertices that are already contained in the accumulator, but right now it doesn't contain any vertices. The first edge of vertex a is directed at vertex b, so we recurse with current set to b and acc set to List(a), because we prepend our current vertex to the existing accumulator (current +: acc).
In the chart, this leads to recursive call 1, abbreviated as 1: call(b,(a)) on the first black arrow, which is equivalent to the call dfsFunctional(b, List(a)).
Vertex b doesn't have any edges, so the body of the for loop never runs and no further recursion happens. Instead, the call immediately returns Set(b) (abbreviated as (b) on the first green arrow). Note that .flatten + current is not part of the for loop! It is applied to the set of sets yielded by the for loop, and evaluates to a set containing all vertices from the recursive calls (if any occurred) plus the vertex that is currently being handled.
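A for comprehension with a guard is syntactic sugar for withFilter followed by map. As a sketch (dfsDesugared is a hypothetical name, and the graph is the one inferred from the walkthrough, with f assumed to have no outgoing edges), here is the same function with that expansion written out by hand, which makes visible that .flatten + current runs once, on the Set[Set[Vertex]] produced by the whole loop:

```scala
// Edges are passed by name so the cyclic example graph can be built (assumption).
class Vertex(val name: Char, outgoing: => Set[Vertex]) {
  lazy val edges: Set[Vertex] = outgoing
  override def toString: String = name.toString
}

object DesugaredExample {
  // Example graph inferred from the walkthrough:
  // a -> b, a -> c, c -> e, e -> f, e -> d, d -> c
  lazy val a: Vertex = new Vertex('a', Set(b, c))
  lazy val b: Vertex = new Vertex('b', Set.empty[Vertex])
  lazy val c: Vertex = new Vertex('c', Set(e))
  lazy val d: Vertex = new Vertex('d', Set(c))
  lazy val e: Vertex = new Vertex('e', Set(f, d))
  lazy val f: Vertex = new Vertex('f', Set.empty[Vertex])

  // dfsFunctional as written in the article.
  def dfsFunctional(current: Vertex, acc: List[Vertex]): Set[Vertex] = {
    (for (next <- current.edges if !acc.contains(next))
      yield dfsFunctional(next, current +: acc)).flatten + current
  }

  // The same function with the for comprehension expanded by hand.
  def dfsDesugared(current: Vertex, acc: List[Vertex]): Set[Vertex] = {
    val childResults: Set[Set[Vertex]] =
      current.edges
        .withFilter(next => !acc.contains(next))         // the `if` guard
        .map(next => dfsDesugared(next, current +: acc)) // the `yield`
    // flatten and + current are applied once, after all children are done
    childResults.flatten + current
  }
}
```

For the leaf b, the withFilter/map step produces an empty Set[Set[Vertex]], so flattening it and adding b yields just Set(b).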
With this we return to recursion level 0. Because the next vertex after b that a points to is c, we recurse again, which leads to 2: call(c,(a)). That is, dfsFunctional is now called with c as the current vertex, and the accumulated list of vertices we know about at level 0 is still only List(a).
From c we have to recurse again, because c points at e, which leads to recursion level 2. Here things get a bit more interesting, because e points to f and d, and d points back to c; that is, the c -> e -> d -> c part of the graph forms a cycle.
But that's not a problem thanks to our accumulator, as can be seen in call 5 on level 3: Here, we investigate d, and would recurse to c, but this is prevented because c is already in the accumulator.
Each recursion returns a set of the results of further recursions (if any) plus the current vertex, and if you follow the green arrows, you can see how this adds up: calls 4 and 5 on level 3 return f and d, respectively, which are returned on level 2 together with e. These are then joined by c on level 1. Level 0 started two recursions, with call 1 resulting in b and call 2 resulting in f, d, e, c. Because level 0 returns these two sets plus the initial vertex a, we end up with a set of all vertices: b, f, d, e, c, a.
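To reproduce the chart as text, the sketch below runs dfsFunctional with an extra level parameter and logs every call. The trace function, the level parameter and the log format are hypothetical additions for illustration; the log buffer is mutable, but only for instrumentation. The graph is the one inferred from the walkthrough (f is assumed to have no outgoing edges).

```scala
import scala.collection.mutable

// Edges are passed by name so the cyclic example graph can be built (assumption).
class Vertex(val name: Char, outgoing: => Set[Vertex]) {
  lazy val edges: Set[Vertex] = outgoing
  override def toString: String = name.toString
}

object TraceExample {
  // Example graph inferred from the walkthrough:
  // a -> b, a -> c, c -> e, e -> f, e -> d, d -> c
  lazy val a: Vertex = new Vertex('a', Set(b, c))
  lazy val b: Vertex = new Vertex('b', Set.empty[Vertex])
  lazy val c: Vertex = new Vertex('c', Set(e))
  lazy val d: Vertex = new Vertex('d', Set(c))
  lazy val e: Vertex = new Vertex('e', Set(f, d))
  lazy val f: Vertex = new Vertex('f', Set.empty[Vertex])

  // Runs dfsFunctional from start and returns the result together with one
  // log entry per call, e.g. "level 2: call(e,(c,a))".
  def trace(start: Vertex): (Set[Vertex], List[String]) = {
    val log = mutable.ListBuffer[String]()
    def dfsTraced(current: Vertex, acc: List[Vertex], level: Int): Set[Vertex] = {
      log += s"level $level: call(${current.name},(${acc.map(_.name).mkString(",")}))"
      (for (next <- current.edges if !acc.contains(next))
        yield dfsTraced(next, current +: acc, level + 1)).flatten + current
    }
    val result = dfsTraced(start, Nil, 0)
    (result, log.toList)
  }
}
```

With the edge order assumed above, the log lists the initial call followed by the five recursive calls from the chart in the same order: b and c on level 1, e on level 2, f and d on level 3. The call for d does not recurse into c, because c is in its accumulator.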
The recursive approach using for works well and doesn't use mutable state, but it is not as time-efficient as dfsMutableRecursive. For the given simple graph this doesn't show, but for a slightly more complex graph like this one:
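the difference quickly becomes apparent: dfsMutableRecursive needs to recurse 6 times, while dfsFunctional needs to recurse 14 times. The reason is that the accumulator of dfsFunctional only contains the vertices on the current path, so a vertex that can be reached via several paths is visited once per path, while dfsMutableRecursive remembers every vertex it has found so far.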
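The complex graph itself is not reproduced here, so the following sketch uses a hypothetical diamond-shaped graph (a -> b, a -> c, b -> d, c -> d, d -> e) and counts the calls of both functions, including the initial call. The counter variables and the count* function names are assumptions added for instrumentation; the traversal logic is unchanged.

```scala
import scala.collection.mutable

// Edges are passed by name so graphs can be built from immutable vertices (assumption).
class Vertex(val name: Char, outgoing: => Set[Vertex]) {
  lazy val edges: Set[Vertex] = outgoing
  override def toString: String = name.toString
}

object CountExample {
  // Hypothetical diamond-shaped graph: d and e are reachable via b and via c.
  // a -> b, a -> c, b -> d, c -> d, d -> e
  lazy val a: Vertex = new Vertex('a', Set(b, c))
  lazy val b: Vertex = new Vertex('b', Set(d))
  lazy val c: Vertex = new Vertex('c', Set(d))
  lazy val d: Vertex = new Vertex('d', Set(e))
  lazy val e: Vertex = new Vertex('e', Set.empty[Vertex])

  // dfsMutableRecursive plus a call counter (the initial call is counted too).
  def countMutableRecursive(start: Vertex): Int = {
    val found = mutable.Set[Vertex]()
    var calls = 0
    def recurse(current: Vertex): Unit = {
      calls += 1
      found += current
      for (next <- current.edges) {
        if (!found.contains(next)) recurse(next)
      }
    }
    recurse(start)
    calls
  }

  // dfsFunctional plus a call counter (the initial call is counted too).
  def countFunctional(start: Vertex): Int = {
    var calls = 0
    def dfsFunctional(current: Vertex, acc: List[Vertex]): Set[Vertex] = {
      calls += 1
      (for (next <- current.edges if !acc.contains(next))
        yield dfsFunctional(next, current +: acc)).flatten + current
    }
    dfsFunctional(start, Nil)
    calls
  }
}
```

On this graph, dfsMutableRecursive makes 5 calls (one per vertex), while dfsFunctional makes 7, because it enters d and e once via b and once more via c. Since the number of paths can grow much faster than the number of vertices, the gap widens on graphs with many converging paths.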
But we can improve our functional approach further. There is a recursive solution based on foldLeft that is as efficient as the one using mutable state. As a bonus, we get rid of the flatten call that was necessary in dfsFunctional:
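def dfsFunctionalFold(current: Vertex, acc: Set[Vertex]): Set[Vertex] = {
  // thread the set of found vertices through all neighbours: each recursive
  // call receives everything found so far, including results of sibling branches
  current.edges.foldLeft(acc) {
    (results, next) =>
      if (results.contains(next)) results
      else dfsFunctionalFold(next, results + current)
  } + current
}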
On the complex graph, this also needs to recurse only 6 times, and, just like dfsMutableRecursive, the number of recursive calls is tied to the number of vertices (one call per vertex reachable from the start): for a 6-vertex graph it is 6, and it grows to 13 for a 13-vertex graph, whereas dfsFunctional grows to 30 recursive calls.
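The same kind of call counter can be attached to dfsFunctionalFold. As a sketch on a hypothetical diamond-shaped graph (a -> b, a -> c, b -> d, c -> d, d -> e), the counter below shows one call per reachable vertex, because the set returned by one branch is passed into the next branch through foldLeft. The graph, the countFold name and the counter are assumptions for illustration.

```scala
// Edges are passed by name so graphs can be built from immutable vertices (assumption).
class Vertex(val name: Char, outgoing: => Set[Vertex]) {
  lazy val edges: Set[Vertex] = outgoing
  override def toString: String = name.toString
}

object FoldCountExample {
  // Hypothetical diamond-shaped graph: d and e are reachable via b and via c.
  // a -> b, a -> c, b -> d, c -> d, d -> e
  lazy val a: Vertex = new Vertex('a', Set(b, c))
  lazy val b: Vertex = new Vertex('b', Set(d))
  lazy val c: Vertex = new Vertex('c', Set(d))
  lazy val d: Vertex = new Vertex('d', Set(e))
  lazy val e: Vertex = new Vertex('e', Set.empty[Vertex])

  // dfsFunctionalFold plus a call counter (the counter is instrumentation only;
  // the initial call is counted too).
  def countFold(start: Vertex): (Set[Vertex], Int) = {
    var calls = 0
    def dfsFunctionalFold(current: Vertex, acc: Set[Vertex]): Set[Vertex] = {
      calls += 1
      current.edges.foldLeft(acc) { (results, next) =>
        // results already contains everything found by earlier siblings
        if (results.contains(next)) results
        else dfsFunctionalFold(next, results + current)
      } + current
    }
    val result = dfsFunctionalFold(start, Set.empty[Vertex])
    (result, calls)
  }
}
```

Here the fold makes 5 calls for the 5 vertices: when the traversal reaches c, d is already in results (found via b), so it is not entered again. The path-based dfsFunctional would make 7 calls on this graph, entering d and e once per path.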