
Monday, February 22, 2010

Phrase groups in clojure

A recent interview with yags* consisted of solving problems and writing demo code for them, which is pretty standard for a Google spinoff. One of the problems caught my attention because it looked like a perfect way to explore the strengths and weaknesses of Clojure.

The problem is simple enough to state: given a set of lists of words, remove every phrase group that occurs in all of the lists. A phrase group is any sequence of three or more consecutive words. Two things make this interesting for Clojure: 1) it's a purely functional problem, since the output depends only on the input values, and 2) it involves manipulating relatively deep structures, at least in the solution I came up with. That is something immutable data structures typically make painful.

Since I'm still exploring Clojure, I chose to do this in Python in the interview. The code here, and in the repository at Bitbucket, uses the same algorithm, expressed with Clojure's high-level data structures instead of Python's.

First, a couple of short helper functions:
(defn mfilter
  "Return a map of the same type as mapin (so a sorted-map stays sorted),
  keeping only the entries for which (pred entry) returns logical true."
  [pred mapin]
  (into (empty mapin) (filter pred mapin)))

(defn enumerate
  "Return pairs of an index into sequence, and the value at that index"
  [sequence]
  (map vector (iterate inc 0) sequence))

As its document string says, mfilter returns a new map built from mapin by removing the entries for which the predicate, applied to the whole [key value] entry, returns false. Likewise, enumerate pairs each element of a sequence with a counter starting at 0, returning [index value] pairs.
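
Because the predicate sees whole entries, callers destructure them. Here is a quick REPL-style check of both helpers. It's a sketch that assumes the version of mfilter built with `into` and `(empty mapin)`, which keeps a sorted-map sorted; the names `evens-by-key` and `pairs` are just for illustration.

```clojure
;; Assumed mfilter: (empty mapin) keeps the input's map type.
(defn mfilter
  "Keep only the entries of mapin for which (pred entry) is truthy."
  [pred mapin]
  (into (empty mapin) (filter pred mapin)))

(defn enumerate
  "Return [index value] pairs for sequence, counting from 0."
  [sequence]
  (map vector (iterate inc 0) sequence))

;; The predicate receives a [key value] entry, so destructure it.
(def evens-by-key
  (mfilter (fn [[k _]] (even? k)) (sorted-map 3 :c 0 :a 2 :b 1 :d)))

;; Pair each element with its position.
(def pairs (enumerate [:x :y :z]))
```

Here `evens-by-key` is the sorted map `{0 :a, 2 :b}` and `pairs` is `([0 :x] [1 :y] [2 :z])`.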

Note that we only have to worry about phrases of length three. A phrase of length four shows up as two overlapping phrases of length three, and one of length five as three such overlapping phrases. So we can ignore phrases longer than three words. And of course, phrases shorter than three words aren't phrase groups, so we ignore them as well. The phrases in an input list therefore form a list of triples with two fewer elements than the input list, since the last two words don't start phrases.
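
To make the counting argument concrete, here is a small sketch that builds the overlapping triples the same way the code further on does, by zipping a list with itself shifted by one and by two. The helper name `triples` and the sample words are hypothetical.

```clojure
;; Hypothetical helper: overlapping three-word phrases of a word list.
(defn triples [words]
  (map vector words (rest words) (nthnext words 2)))

;; Five words give three overlapping triples; the last two words start none.
(def five (triples ["the" "quick" "brown" "fox" "jumps"]))

;; Fewer than three words give no phrases at all.
(def short-one (triples ["too" "short"]))
```

Core `partition` with a step of 1, `(partition 3 1 words)`, produces the same triples as seqs instead of vectors.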

The data structure used for my solution is a dictionary indexed by phrase groups, meaning triples of words. Each entry is itself a dictionary, indexed by an input list's position in the list of input lists. The values in those inner dictionaries are lists of the positions where that phrase group starts in that input list. I'm going to call this the phrase dictionary.
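
A literal example may make the nesting easier to picture. This is a hand-built sketch for two made-up input lists; the name `example-phrase-dict` is hypothetical, and the positions are stored newest-first because `conj` on a list adds to the front.

```clojure
;; Hypothetical phrase dictionary for two input lists:
;;   list 0: ["a" "b" "c" "d" "a" "b" "c"]
;;   list 1: ["x" "a" "b" "c"]
;; Keys are word triples; each value maps a list index to the
;; positions where that triple starts in that list.
(def example-phrase-dict
  {["a" "b" "c"] {0 '(4 0), 1 '(1)}
   ["b" "c" "d"] {0 '(1)}
   ["c" "d" "a"] {0 '(2)}
   ["d" "a" "b"] {0 '(3)}})

;; Where does ["a" "b" "c"] start in list 0?
(def starts-in-list-0 (get-in example-phrase-dict [["a" "b" "c"] 0]))
```

Only `["a" "b" "c"]` has an entry for both lists, so it is the one shared phrase group here.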

That leads us to the first building-block function. It accepts a phrase dictionary, a phrase, the index of an input list, and the position of the phrase in that list, and returns the updated phrase dictionary:
(defn add-phrase-to-phrase-dict [phrase-dict phrase list-index phrase-index]
  (if (or (phrase-dict phrase) (= list-index 0))
    (update-in phrase-dict [phrase list-index] conj phrase-index)
    phrase-dict))

This uses the Clojure function update-in, which is something I don't remember ever seeing in a Lisp before (though Common Lisp's setf might be close). Its first argument is a nested structure, like the phrase dictionary, and its second is a sequence of keys that locates an element inside it. Here that's [phrase list-index], so it finds the list of places where phrase appears in input list list-index. If keys along that path are missing, update-in creates the intermediate maps it needs, and the final lookup yields nil. The value found is passed to the function given as the third argument (conj here), along with any remaining arguments, and the result of that call is stored at that position in the new version of the structure that update-in returns. While this might sound expensive, since everything in the structure is immutable, the two versions can share everything except the maps along the path to the updated value.
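
Here is a small sketch of that behavior on an empty map, using a made-up phrase. The first call has to create the inner map and ends up calling `(conj nil 0)`, which returns the list `(0)`; the second call conjes onto that list. The names `d1` and `d2` are just for illustration.

```clojure
;; Missing keys: update-in builds the inner map, conj on nil makes a list.
(def d1 (update-in {} [["a" "b" "c"] 0] conj 0))

;; Existing path: conj adds to the front of the stored list.
(def d2 (update-in d1 [["a" "b" "c"] 0] conj 4))

;; d1 is unchanged by the second update; d2 is a new map that
;; shares structure with it.
```

After this, `d1` is `{["a" "b" "c"] {0 (0)}}` and `d2` is `{["a" "b" "c"] {0 (4 0)}}`.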

The conditional around update-in is an optimization. If the phrase we're updating isn't in the dictionary yet, we want two different behaviors. If this is the first input list, we add the phrase to the dictionary. Otherwise, we can ignore it, because it's missing from at least one input list: the first one. Adding it wouldn't change the final result, but would create more work. So we check for the phrase first, and only add a new one if this is the first input list; otherwise the function returns the input phrase dictionary unmodified.
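
To see the optimization in action, here is the function applied by hand to a tiny dictionary. The phrases and the `def` names are made up for the example.

```clojure
(defn add-phrase-to-phrase-dict [phrase-dict phrase list-index phrase-index]
  ;; Record the phrase only if it's already known or we're on list 0.
  (if (or (phrase-dict phrase) (= list-index 0))
    (update-in phrase-dict [phrase list-index] conj phrase-index)
    phrase-dict))

;; From list 0, a new phrase is added.
(def after-first (add-phrase-to-phrase-dict {} ["a" "b" "c"] 0 0))

;; From list 1, a phrase that list 0 lacks is ignored...
(def ignored (add-phrase-to-phrase-dict after-first ["x" "a" "b"] 1 0))

;; ...but a phrase already present gets an entry for list 1.
(def extended (add-phrase-to-phrase-dict after-first ["a" "b" "c"] 1 1))
```

The ignored case hands back the very same map object, so no work is done for phrases that can't be shared by every list.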

We now need to invoke this function on every phrase in every input list. Let's start with a function to invoke it on every phrase in a single input list:
(defn add-list-to-phrase-dict [phrase-dict list-index list]
  (reduce (fn [phrase-dict [phrase-index phrase]]
            (add-phrase-to-phrase-dict phrase-dict phrase list-index phrase-index))
          phrase-dict
          (enumerate (map vector list (rest list) (nthnext list 2)))))

As input, we get the phrase dictionary we're going to update, the index of the input list in the list of lists, and the input list itself. Updating a value by processing each element of a list is what the Lisp function reduce is for. It takes a function, an initial value, and a list, and calls the function with the accumulated value (starting with the initial value) and each element of the list in turn, feeding each result into the next call. So our function, introduced by (fn, takes two arguments, a phrase dictionary and a pair of phrase-index and phrase, and returns the result of calling add-phrase-to-phrase-dict with them and list-index. The initial value is the input phrase dictionary. We generate the phrase list by mapping vector over the input list, the input list with its first element dropped, and the input list with its first two elements dropped, giving three-element vectors of consecutive words, which are our phrases. Again, the primitive does the right thing for us: map stops as soon as the shortest list runs out, so the last two words of the input list don't start phrases. We pass that list to enumerate to get the index and phrase pairs we need.
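
Here is a sketch that runs add-list-to-phrase-dict over two hypothetical word lists, one after the other, the same way the next function will. It repeats the helpers so it can be loaded on its own.

```clojure
(defn enumerate [sequence]
  (map vector (iterate inc 0) sequence))

(defn add-phrase-to-phrase-dict [phrase-dict phrase list-index phrase-index]
  (if (or (phrase-dict phrase) (= list-index 0))
    (update-in phrase-dict [phrase list-index] conj phrase-index)
    phrase-dict))

(defn add-list-to-phrase-dict [phrase-dict list-index list]
  ;; reduce threads the dictionary through one call per [index phrase] pair.
  (reduce (fn [phrase-dict [phrase-index phrase]]
            (add-phrase-to-phrase-dict phrase-dict phrase list-index phrase-index))
          phrase-dict
          (enumerate (map vector list (rest list) (nthnext list 2)))))

;; List 0 seeds the dictionary with all five of its triples (four distinct).
(def from-first
  (add-list-to-phrase-dict {} 0 ["a" "b" "c" "d" "a" "b" "c"]))

;; List 1 can only extend phrases that list 0 already has.
(def from-both
  (add-list-to-phrase-dict from-first 1 ["x" "a" "b" "c"]))
```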

Given that, we basically repeat the same construct to deal with all the input lists:
(defn build-phrase-dict [lists]
  (let [phrase-dict (reduce (fn [phrase-dict [list-index list]]
                              (add-list-to-phrase-dict phrase-dict
                                                       list-index list))
                            {}
                            (enumerate lists))
        list-count (count lists)]
    (mfilter #(= (count (val %)) list-count) phrase-dict)))

The input is a sequence of lists. That's the sequence we reduce over here, once again passing it through enumerate to generate our indices. This time, the initial value is an empty map. The most important change is that, instead of returning the result of reduce directly, we use the mfilter function defined earlier to remove every phrase whose entry doesn't have as many list entries as there are input lists. That's the criterion for inclusion: the phrase must be in every list. If it's missing from some list, its dictionary entry won't have an entry for that list, so the count won't match the length of the list of lists.
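
Putting the pieces together, here is a self-contained sketch of build-phrase-dict on the same two hypothetical lists. It assumes the `into`/`empty` version of mfilter; the final filter keeps only phrases with an entry for both lists.

```clojure
(defn mfilter [pred mapin]
  (into (empty mapin) (filter pred mapin)))

(defn enumerate [sequence]
  (map vector (iterate inc 0) sequence))

(defn add-phrase-to-phrase-dict [phrase-dict phrase list-index phrase-index]
  (if (or (phrase-dict phrase) (= list-index 0))
    (update-in phrase-dict [phrase list-index] conj phrase-index)
    phrase-dict))

(defn add-list-to-phrase-dict [phrase-dict list-index list]
  (reduce (fn [phrase-dict [phrase-index phrase]]
            (add-phrase-to-phrase-dict phrase-dict phrase list-index phrase-index))
          phrase-dict
          (enumerate (map vector list (rest list) (nthnext list 2)))))

(defn build-phrase-dict [lists]
  (let [phrase-dict (reduce (fn [phrase-dict [list-index list]]
                              (add-list-to-phrase-dict phrase-dict list-index list))
                            {}
                            (enumerate lists))
        list-count (count lists)]
    ;; Keep phrases that have an entry for every input list.
    (mfilter #(= (count (val %)) list-count) phrase-dict)))

(def shared
  (build-phrase-dict [["a" "b" "c" "d" "a" "b" "c"]
                      ["x" "a" "b" "c"]]))
```

Here `shared` is `{["a" "b" "c"] {0 (4 0), 1 (1)}}`: the three phrases found only in list 0 were dropped by the count check.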

That's the first half of the problem: building the phrase dictionary. Now we need to use it to remove the phrases from the input lists. Oddly enough, the phrases themselves don't matter; we only care about their positions. So we write a function that takes a list of phrase start positions and removes the words those phrases cover from a list:
(defn remove-phrases-from-list [phrase-starts list]
  (mfilter (fn [[x _]]
             (not-any? #(and (>= x %) (< x (+ % 3))) phrase-starts))
           (apply sorted-map (mapcat vector (iterate inc 0) list))))

Here's the second use of mfilter; this time it removes words from the list. We start by building a sorted map of the list's words, keyed by word position. Then mfilter uses not-any? to check each word's position against the phrase start positions: a position is covered by a phrase if it's greater than or equal to that phrase's start and less than the start plus three. If a position is covered by any phrase in phrase-starts, the word is removed.
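
A short sketch of the position check, assuming the `into`/`empty` version of mfilter so the result stays a sorted map. The word lists are made up; the last case uses two overlapping starts, which is how a shared four-word phrase shows up.

```clojure
(defn mfilter [pred mapin]
  (into (empty mapin) (filter pred mapin)))

(defn remove-phrases-from-list [phrase-starts list]
  ;; Drop position x if some start s satisfies s <= x < s + 3.
  (mfilter (fn [[x _]]
             (not-any? #(and (>= x %) (< x (+ % 3))) phrase-starts))
           (apply sorted-map (mapcat vector (iterate inc 0) list))))

;; Phrases start at 0 and 4, covering positions 0-2 and 4-6.
(def kept (remove-phrases-from-list '(4 0) ["a" "b" "c" "d" "a" "b" "c"]))

;; Overlapping starts 0 and 1 cover positions 0-3.
(def overlap (remove-phrases-from-list [0 1] ["p" "q" "r" "s" "t"]))
```

Only `"d"` survives in `kept`, and only `"t"` in `overlap`.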

So all that's left is printing the result:
(defn remove-shared-phrases [lists]
  (let [phrases (apply merge-with concat (vals (build-phrase-dict lists)))]
    (doseq [[idx list] (enumerate lists)]
      ;; get, not (phrases idx): phrases is nil when no phrase is shared.
      (println (vals (remove-phrases-from-list (get phrases idx) list))))))

This combines everything so far: it builds the phrase dictionary from the lists, extracts the phrase positions from it by using merge-with to concat the per-list position lists in each entry into a single list for each input list, then uses doseq to call println on the result of removing those phrases from each list. At this point, we see why the map in remove-phrases-from-list was a sorted-map: it makes each list's remaining words print in their original order. That is also why mfilter must return the same kind of map it was given, since a plain hash-map wouldn't keep the words in order.
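
To check the whole pipeline without capturing printed output, here is a sketch with a hypothetical variant, `shared-phrases-removed`, that collects the remaining words instead of printing them. It assumes the `into`/`empty` version of mfilter and uses `get`, so input with no shared phrase groups doesn't try to call nil as a function.

```clojure
(defn mfilter [pred mapin]
  (into (empty mapin) (filter pred mapin)))

(defn enumerate [sequence]
  (map vector (iterate inc 0) sequence))

(defn add-phrase-to-phrase-dict [phrase-dict phrase list-index phrase-index]
  (if (or (phrase-dict phrase) (= list-index 0))
    (update-in phrase-dict [phrase list-index] conj phrase-index)
    phrase-dict))

(defn add-list-to-phrase-dict [phrase-dict list-index list]
  (reduce (fn [phrase-dict [phrase-index phrase]]
            (add-phrase-to-phrase-dict phrase-dict phrase list-index phrase-index))
          phrase-dict
          (enumerate (map vector list (rest list) (nthnext list 2)))))

(defn build-phrase-dict [lists]
  (let [phrase-dict (reduce (fn [phrase-dict [list-index list]]
                              (add-list-to-phrase-dict phrase-dict list-index list))
                            {}
                            (enumerate lists))
        list-count (count lists)]
    (mfilter #(= (count (val %)) list-count) phrase-dict)))

(defn remove-phrases-from-list [phrase-starts list]
  (mfilter (fn [[x _]]
             (not-any? #(and (>= x %) (< x (+ % 3))) phrase-starts))
           (apply sorted-map (mapcat vector (iterate inc 0) list))))

;; Hypothetical variant of remove-shared-phrases that returns the
;; remaining words of each list as vectors instead of printing them.
(defn shared-phrases-removed [lists]
  (let [phrases (apply merge-with concat (vals (build-phrase-dict lists)))]
    (vec (for [[idx list] (enumerate lists)]
           (vec (vals (remove-phrases-from-list (get phrases idx) list)))))))

(def result
  (shared-phrases-removed [["a" "b" "c" "d" "a" "b" "c"]
                           ["x" "a" "b" "c"]]))
```

Here `result` is `[["d"] ["x"]]`, and lists with nothing in common come back unchanged.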

All in all, this has been an illuminating exercise. Clojure's functions for nested and non-list structures work very well, providing what still feels like a very Lisp-like environment even though we're not working with lists. In particular, the cases that needed special handling in Python, such as missing keys in nested maps or running off the end of a list, do the right thing by default, which feels very much like classic Lisp behavior. I'd say Clojure held up very well here.

*) That would be Yet Another Google Spinoff.