Paper Summary - Toward Interpretable Deep Reinforcement Learning with Linear Model U-Trees by Liu et al. (2018)


Liu, G., Schulte, O., Zhu, W., & Li, Q. (2018, September). Toward interpretable deep reinforcement learning with linear model u-trees. In Joint European Conference on Machine Learning and Knowledge Discovery in Databases (pp. 414-429). Springer, Cham.

My Objectives

To learn about a specific approach which generates explanations for decisions made by RL agents and use it for my own project.

Paper Summary

This paper presents a technique to approximate (mimic) the q-function of an RL agent using a novel model called a Linear Model U-Tree (LMUT). It presents an algorithm to learn this model and to use it to generate explanations in the form of feature importance rankings, rule extraction for explaining agent behaviour, and feature importance to identify super-pixels in image input. The model is evaluated computationally by measuring the fidelity of its q-value predictions to the original agent and the reward it obtains when acting in various RL environments.

Methodology

The paper describes a model which approximates the q-function of an RL agent. A q-function $Q(s, a)$ represents the expected cumulative (discounted) reward obtained by executing action $a$ in state $s$ and following the agent’s policy thereafter. The agent being approximated must be able to output $Q(s, a)$ given $s$ and $a$ for this technique to work.

The model learned is an LMUT. It is based on an existing model called a U-Tree which also approximates the q-function of an RL agent. The novelty of the model is in using linear models in the leaf nodes of the tree instead of constants. The linear models predict the q-value in terms of state features. The usage of linear models instead of constants improves the generalizability of the U-Tree, allowing it to better approximate an agent’s q-function with fewer leaf nodes.
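
To make the structure concrete, here is a minimal sketch of a tree with linear models in its leaves and of how a q-value prediction could be read off it. The `Node` class, the `find_leaf` and `predict_q` helpers, the "go left if feature <= threshold" convention and the toy tree are all my own assumptions for illustration; the paper does not prescribe this layout, and the sketch handles a single action only.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Node:
    # Hypothetical layout: internal nodes use feature/threshold,
    # leaf nodes use weights/bias (the linear model).
    feature: Optional[int] = None   # index of the state feature to split on
    threshold: float = 0.0          # assumed rule: go left if state[feature] <= threshold
    left: Optional["Node"] = None
    right: Optional["Node"] = None
    weights: List[float] = field(default_factory=list)
    bias: float = 0.0

    def is_leaf(self) -> bool:
        return self.left is None and self.right is None


def find_leaf(node: Node, state: List[float]) -> Node:
    """Follow the feature splits from the root down to a leaf."""
    while not node.is_leaf():
        node = node.left if state[node.feature] <= node.threshold else node.right
    return node


def predict_q(root: Node, state: List[float]) -> float:
    """Q-value estimate from the linear model stored in the reached leaf."""
    leaf = find_leaf(root, state)
    return sum(w * x for w, x in zip(leaf.weights, state)) + leaf.bias


# Toy tree (made-up numbers): one split on feature 0 at 0.5, two linear leaves.
root = Node(
    feature=0,
    threshold=0.5,
    left=Node(weights=[1.0, 2.0], bias=0.0),
    right=Node(weights=[0.0, -1.0], bias=4.0),
)
q_left = predict_q(root, [0.0, 1.0])   # reaches the left leaf
q_right = predict_q(root, [1.0, 1.0])  # reaches the right leaf
```

A plain U-Tree would correspond to every leaf having all-zero `weights` and only a `bias`, which is why it needs more leaves to fit the same function.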

Training

The training algorithm presented in the paper only describes how to learn the weights of the linear models in the leaf nodes, and how to split the leaf nodes of an already existing LMUT. How the initial structure of the tree is learned has not been touched upon.

Dataset

The dataset used for training the LMUT can be obtained during the training process of the RL agent (experience training) or by using a trained RL agent in an environment (active play).

In the experience training setting, records (transitions) of the form $(s_t, a_t, Q(s_t, a_t))$ are collected during the training process of the RL agent, where it is given $s_t$ and $a_t$ as input and learns to predict $Q(s_t, a_t)$ ($s_t$ and $a_t$ are the state and action executed at timestep $t$ respectively). These records serve as the dataset for learning the LMUT model.

In the active play setting, we are given only the pre-trained RL agent. We operate it in an environment (the authors use an $\epsilon$-greedy strategy for selecting actions) and again obtain a dataset of records (transitions), this time of the form $(s_t, a_t, r_t, s_{t+1}, Q(s_t, a_t))$, where $s_t$, $a_t$ and $Q(s_t, a_t)$ are the same as before, $r_t$ is the reward obtained at timestep $t$, and $s_{t+1}$ is the successor state. The dataset obtained from the active play setting is used to learn an LMUT in an online, minibatch fashion, i.e., as transitions are continually generated, groups (minibatches) of some size $B$ are collected to train the LMUT with.
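
The active play loop can be sketched roughly as below. This is only an illustration under my own assumptions: `toy_q`, `toy_reset` and `toy_step` are hypothetical stand-ins for a trained agent's q-function and a 1-D environment, and the `active_play` generator, its parameters and its handling of episode ends are not taken from the paper.

```python
import random
from typing import Iterator, List, Tuple

State = Tuple[float, ...]
# One record: (s_t, a_t, r_t, s_{t+1}, Q(s_t, a_t))
Transition = Tuple[State, int, float, State, float]


def epsilon_greedy(q_fn, state, n_actions, epsilon, rng):
    """Pick a random action with probability epsilon, else the greedy one."""
    if rng.random() < epsilon:
        return rng.randrange(n_actions)
    return max(range(n_actions), key=lambda a: q_fn(state, a))


def active_play(q_fn, reset, step, n_actions, n_steps, batch_size,
                epsilon=0.1, seed=0) -> Iterator[List[Transition]]:
    """Run the agent and yield minibatches of `batch_size` transitions."""
    rng = random.Random(seed)
    state = reset()
    batch: List[Transition] = []
    for _ in range(n_steps):
        action = epsilon_greedy(q_fn, state, n_actions, epsilon, rng)
        next_state, reward, done = step(state, action)
        batch.append((state, action, reward, next_state, q_fn(state, action)))
        state = reset() if done else next_state
        if len(batch) == batch_size:
            yield batch
            batch = []
    # Any incomplete final batch is simply dropped in this sketch.


# Hypothetical stand-ins for a trained agent and an environment.
def toy_q(state, action):
    return state[0] + (1.0 if action == 1 else -1.0)


def toy_reset():
    return (0.0,)


def toy_step(state, action):
    position = state[0] + (1.0 if action == 1 else -1.0)
    done = abs(position) >= 3.0
    reward = 1.0 if position >= 3.0 else 0.0
    return (position,), reward, done


batches = list(active_play(toy_q, toy_reset, toy_step,
                           n_actions=2, n_steps=10, batch_size=4))
```

Each yielded minibatch would then be handed to the LMUT training step. The experience training records are the same tuples without the $r_t$ and $s_{t+1}$ entries.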

The dataset obtained from active play contains more information than that obtained from experience training. The training algorithm described later can use the extra information like $r_t$ and $s_{t+1}$ that only comes from the active play dataset to build an MDP of the environment. I will use $(s_t, a_t, Q(s_t, a_t))$ to denote a transition for brevity.

Active play offers benefits over experience training for generating the dataset. It only requires the trained RL agent, which is more commonly available to mimic learners. This is because training effective RL agents for complex environments requires an enormous amount of compute time and resources, which makes dataset generation using experience training impractical. Lastly, the experience training data can contain many states which are not visited under the optimal policy, making it difficult for the mimic learner to obtain an optimal return.

The last point did not entirely convince me.

The algorithm described assumes that a U-Tree with feature splits exists such that we can use a transition $(s_t, a_t, Q(s_t, a_t))$ to traverse it down to a leaf node. I assume the U-Tree structure can be learned from the dataset, but it would have been nice if the authors had touched upon this in the paper.

The algorithm for learning the LMUT is divided into two phases -

  1. Data Gathering Phase

    In this phase, every transition is used to traverse the tree down to a leaf node, where it is added to the transition set of that node. If we have the dataset from the active play setting, we can use the extra information to learn an MDP of the environment. The paper does not describe what can be done with this MDP.

  2. Node Splitting Phase

    In this phase, the weights of the linear models in the leaf nodes are updated using the node’s transition set. The linear model in a leaf node $N$ predicts $Q(s_t, a_t)$ given an input transition $(s_t, a_t)$ in terms of the feature set of $s_t$. The weights are updated using SGD in minibatches. The next step in this phase is to see if we can split the nodes to obtain a better fit to the original RL agent’s q-function. For every leaf node, we test every possible split (in terms of state feature values) and use a splitting criterion to see if the splits are better than the (pre-defined) threshold. If yes, we keep the splits, assign the transitions in the split parent node to the children appropriately, and use SGD again to update the weights of the linear models in the children.

    The authors tried 3 different splitting criteria to implement this phase, namely -

    1. working response of SGD: improve the working response of the parent linear model on the data for its children.
    2. Kolmogorov-Smirnov test: a non-parametric statistical test that measures the differences in empirical cumulative distribution functions between the child data.
    3. Variance Test: generates child nodes whose q-values contain the least variance.

    The authors found the Variance Test to produce good splits with lower time complexity than the K-S test, and so used it as the main splitting criterion.

    This phase seems to split the leaf nodes only once. It is unclear if we should keep repeating this phase until no nodes can be split.
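
Here is how I picture the two steps of the node splitting phase for a single leaf: one minibatch gradient step on the leaf's linear model, followed by a variance-based search over candidate splits. It is a sketch under my own assumptions rather than the authors' implementation: the squared-error loss, the learning rate, the use of midpoints between observed feature values as candidate thresholds, the population variance and the `min_gain` threshold are all choices I made for illustration.

```python
from statistics import pvariance


def sgd_update(weights, bias, states, q_values, lr=0.01):
    """One minibatch gradient step on (half) the mean squared error."""
    n = len(states)
    grad_w = [0.0] * len(weights)
    grad_b = 0.0
    for s, q in zip(states, q_values):
        err = sum(w * x for w, x in zip(weights, s)) + bias - q
        for j, x in enumerate(s):
            grad_w[j] += err * x / n
        grad_b += err / n
    new_weights = [w - lr * g for w, g in zip(weights, grad_w)]
    return new_weights, bias - lr * grad_b


def weighted_child_variance(left_q, right_q):
    """Frequency-weighted variance of the q-values in the two children."""
    n = len(left_q) + len(right_q)
    return (len(left_q) * pvariance(left_q)
            + len(right_q) * pvariance(right_q)) / n


def best_variance_split(states, q_values, min_gain):
    """Return (feature, threshold, gain) for the best split, or None."""
    parent_var = pvariance(q_values)
    best = None
    for f in range(len(states[0])):
        values = sorted({s[f] for s in states})
        # Assumed candidates: midpoints between consecutive observed values,
        # so both children are always non-empty.
        for lo, hi in zip(values, values[1:]):
            thr = (lo + hi) / 2
            left = [q for s, q in zip(states, q_values) if s[f] <= thr]
            right = [q for s, q in zip(states, q_values) if s[f] > thr]
            gain = parent_var - weighted_child_variance(left, right)
            if gain > min_gain and (best is None or gain > best[2]):
                best = (f, thr, gain)
    return best


# Made-up transition set of one leaf: q-values depend mostly on feature 0.
states = [(0.0, 5.0), (0.1, 1.0), (0.9, 4.0), (1.0, 2.0)]
q_values = [1.0, 1.2, 3.0, 3.1]

weights, bias = sgd_update([0.0, 0.0], 0.0, states, q_values)
split = best_variance_split(states, q_values, min_gain=0.1)
no_split = best_variance_split(states, q_values, min_gain=10.0)
```

On this toy data the search picks feature 0 with a threshold of 0.5, and returns `None` when the required gain is set too high. Whether the real algorithm loops this until no leaf can be split is exactly the point I found unclear.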

Explanations

The paper presents three different kinds of explanations that can be generated by this model.

  1. Feature influence ranking

    A global explanation in the form of a feature influence ranking can be generated to inform the user of the relative importance of features in the agent’s policy. These features are state features i.e., they describe a state of the environment. They are calculated using a formula presented in the paper which combines the relative magnitude of the weight of the feature in the linear models of the LMUT, and the reduction in variance of the q-values as a result of splitting by that feature. The exact formula is reproduced from Equation 3 in the paper below -

    \[Inf_f^N = \left( 1 + \frac{|w_{N_f}|}{\sum_{j=1}^{J} |w_{N_j}|} \right) \left( var_N - \sum_{c=1}^{C} \frac{Num_c}{\sum_{i=1}^{C} Num_i} var_c \right)\]

    The first term is the multiplier based on the relative weight of the feature $f$ in node $N$. The second term is the difference in the q-value variance in node $N$ and the frequency-weighted q-value variance among the child nodes. The influence of a feature is the sum of the influences of the feature for all nodes split by that feature.

  2. Rule-extraction

    Rules in the form of partition cells are extracted which are constructed from the splitting features of the LMUT. The cells contain the range of feature values, and a Q-vector representing the average Q-value in the cell (I believe it represents the Q-value for each valid action taken in that state).

    How to extract these rules is not adequately described in the paper. Due to the lack of clarity on how to construct the initial U-tree with state feature splits, it is unclear to me how to obtain these rules from an LMUT.

  3. Super-pixel Extraction

    This is feature importance applied to image input, where individual pixels are the features. This form of explanation simply identifies the pixels which contribute most to an RL agent’s q-value predictions.
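
The feature influence formula (Equation 3 in the paper) can be computed along the following lines. The function names, the tuple layout used to describe a split node and the use of population variance are my own assumptions; the numbers are made up.

```python
from statistics import pvariance


def node_influence(split_feature, weights, parent_q, children_q):
    """Influence of the splitting feature at one node.

    weights    : linear-model weights of the node (one per state feature)
    parent_q   : q-values of the transitions in the node
    children_q : list of q-value lists, one per child node
    """
    total_w = sum(abs(w) for w in weights)
    rel_weight = abs(weights[split_feature]) / total_w if total_w else 0.0
    n = sum(len(c) for c in children_q)
    child_var = sum(len(c) / n * pvariance(c) for c in children_q)
    return (1.0 + rel_weight) * (pvariance(parent_q) - child_var)


def feature_influences(split_nodes, n_features):
    """Sum the per-node influences over all nodes split by each feature."""
    totals = [0.0] * n_features
    for split_feature, weights, parent_q, children_q in split_nodes:
        totals[split_feature] += node_influence(
            split_feature, weights, parent_q, children_q)
    return totals


# One hypothetical node split on feature 0.
split_nodes = [
    (0, [3.0, 1.0], [1.0, 1.0, 3.0, 3.0], [[1.0, 1.0], [3.0, 3.0]]),
]
influences = feature_influences(split_nodes, n_features=2)
ranking = sorted(range(2), key=lambda f: influences[f], reverse=True)
```

Here the relative weight of feature 0 is 0.75 and the split removes all of the variance (1.0), so its influence is 1.75, while feature 1 is never used for a split and gets 0.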

Experimental Setup

Metrics

The authors evaluate the LMUT model along two directions -

  1. Fidelity to the original RL agent

    The LMUT should predict Q-values in close agreement to the predictions made by the original RL agent which is being mimicked. The authors use Mean Absolute Error (MAE) and Root Mean Square Error (RMSE) as metrics.

  2. Matching game-playing performance

    The LMUT should obtain rewards comparable to the original RL agent when operating in the environment itself using actions suggested by the mimicked q-function. This is measured using Average Reward Per Episode (ARPE) which is again a common evaluation metric for mimic models.
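
For reference, the three metrics can be computed as in this small sketch. The function names and the example numbers are mine, not the paper's.

```python
import math


def mae(predicted, target):
    """Mean absolute error between mimic and agent q-values."""
    return sum(abs(p - t) for p, t in zip(predicted, target)) / len(target)


def rmse(predicted, target):
    """Root mean square error between mimic and agent q-values."""
    return math.sqrt(
        sum((p - t) ** 2 for p, t in zip(predicted, target)) / len(target))


def arpe(episode_rewards):
    """Average Reward Per Episode: mean of the summed reward of each episode."""
    return sum(sum(ep) for ep in episode_rewards) / len(episode_rewards)


# Made-up q-values from the original agent and from the mimic model.
agent_q = [1.0, 2.0, 3.0]
mimic_q = [1.5, 2.0, 2.0]
fidelity_mae = mae(mimic_q, agent_q)
fidelity_rmse = rmse(mimic_q, agent_q)

# Made-up per-step rewards of two episodes played with the mimic model.
play_arpe = arpe([[1.0, 1.0, 1.0], [1.0, 1.0]])
```

RMSE penalises the single large error more heavily than MAE does, which is why reporting both gives a slightly fuller picture of fidelity.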

Environment

RL environments from the OpenAI Gym toolkit are used, namely Cart Pole, Mountain Car and Flappy Bird.

Benchmarks

The paper uses CART, M5-RT and M5-MT as benchmark models which use the experience training dataset in a batch learning fashion. For active play, the paper uses FIMT and FIMT-AF.

I don’t understand why the batch learning benchmarks cannot be used in active play with the same minibatch approach used by LMUT. FIMT can update given just one transition, whereas LMUT requires a minibatch to do weight updates. The other batch learning methods can also use a minibatch from the active play dataset stream to perform their updates.

Datasets were generated for each of the environments using the experience training and active play settings, and the performance of the respective benchmarks and LMUT was measured in terms of the metrics mentioned previously. The number of leaves in the learned models was also measured. Consecutive learning was used to plot learning curves for LMUT to measure the rate of convergence in the different environments.

Results & Analysis

LMUT was found to outperform the benchmarks across all metrics. It couldn’t match the performance of the original agent, however. The difference was significant (p < 0.05) as measured by a t-test.

Discussion

The authors describe why LMUT outperforms the benchmarks. They hypothesize that the batch learning benchmarks use a dataset which contains suboptimal states, which don’t lead to optimal rewards. Since they generalize over this dataset, they do not perform as well as LMUT, which gives more importance to more recent states. The recent states would have been visited by a mature RL agent, which means they must be close to optimal. Thus, LMUT tends to give better rewards.

This doesn’t seem like a very convincing explanation since the LMUT training algorithm doesn’t suggest that it favours “recent” states over others. Additionally, using a dataset generated by a mature RL agent to train a batch learner would eliminate the effects of suboptimal states.

Impressions

The primary difficulty I had with this paper was a lack of clarity on how the learning algorithm worked, primarily how the U-tree is initially learned. Learning this would help me understand the node-splitting phase of the learning algorithm, as well as how the rule-extraction works.

A lot of the explainability of the LMUT model is made by appeal to the “transparent nature” of tree models. There is no attempt to more directly evaluate their interpretability by use of user studies. Perhaps a concerted effort to evaluate XRL technique interpretability was only made more recently.

The technique seems to use the same form of dataset from the original model as is used in Madumal et al. (2019). It seems the most model-agnostic way to mimic an RL agent is to use the transitions obtained from either the training or the active play setting. I wonder if this approach of approximating the q-function is general enough. Do all RL agents have a q-function which they use to make decisions?

This is notwithstanding that the goal here is to generate useful explanations, not to have an accurate mimic model. The usefulness of the explanations is not really evaluated, so all the explanations provided by the model are just useful curiosities, and do not have quantitative results regarding their usefulness.