When Graph Neural Networks Meet Reinforcement Learning

Reinforcement Learning (RL) agents should be able to efficiently generalize to novel situations and transfer their learned skills. Without these properties, such agents would always have to learn from scratch, even though they have already mastered primitive skills that could potentially be leveraged to acquire more complex ones.

Combining primitive skills and building upon them to solve harder tasks is a key challenge in artificial intelligence. For goal-conditioned agents, transfer and adaptability seem to depend on two key features: the design of the goal space and the architecture of the policy. On the one hand, the goal representation, whether it is learned or predefined, should encapsulate an adequate structure that defines a specific topology in the goal space.

On the other hand, the behavior of artificial agents depends not only on how they represent their goals, but also on how they act. We therefore investigate Graph Neural Networks (GNNs) as technical tools to model policies in autotelic agents. This choice is also motivated by developmental approaches, as research in psychology shows that humans perceive their world in a structured fashion [Winston 1970; Palmer 1975; Navon 1977; Markman 1989; Kemp and Tenenbaum 2008; Tenenbaum et al. 2011; Battaglia et al. 2016; Battaglia et al. 2018; Godfrey-Smith 2021].

This blog post is organized as follows. First, we introduce GNNs as technical tools to endow artificial agents with relational inductive biases. Then, we present an overview of the use of GNNs in RL. Finally, we highlight several limitations of this combination.

Graph Neural Networks

Recently, deep learning methods have been used to solve a large number of problems in many domains, ranging from object detection [Redmon et al. 2016; Ren et al. 2015] and video processing [Zhang et al. 2016] to speech recognition [Hinton et al. 2012] and neural machine translation [Luong et al. 2015; Wu et al. 2017]. These methods use parameterized neural networks as building blocks and are usually trained end-to-end, requiring few to no assumptions about their inputs. They feed their networks with raw streams of data, usually represented in Euclidean space (for instance, images as grids of pixels or text as sequences of tokens). However, in many applications the data instead lives in non-Euclidean domains and is naturally represented as graphs with complex relationships and inter-dependencies. Standard deep learning architectures usually struggle with this type of irregular data.

Researchers have therefore explored ways of leveraging graph-based information with neural networks. In particular, Graph Neural Networks (GNNs) were proposed as computational frameworks that process graph-structured data using neural networks shared across nodes and edges [Wang et al. 2016; Battaglia et al. 2016; Santoro et al. 2017; Zaheer et al. 2017; Hamrick et al. 2017; Sanchez-Gonzalez et al. 2018; Battaglia et al. 2018; Zambaldi et al. 2018; Wang et al. 2018; Bapst et al. 2019; Li et al. 2019; Colas et al. 2020; Akakzia et al. 2021; Akakzia and Sigaud 2022]. Although these methods all build on the same idea, they differ in how they define and organize the computations within the GNN. Several surveys propose different taxonomies of GNN-based methods [Bronstein et al. 2017; Hamilton et al. 2017; Battaglia et al. 2018; Lee et al. 2018; Wu et al. 2020]. Rather than presenting an exhaustive survey, this blog post introduces the main building blocks of GNNs, namely their definitions and computational schemes. We then focus on applications in RL and give a short overview of representative methods.

Relational Inductive Bias with Graph Neural Networks

We start by defining the central component of GNNs: the graph.

Graph. A graph is a mathematical structure used to model pairwise relations between objects. More formally, we denote a graph by an ordered pair \(G=(V, E)\), where \(V\) is the set of vertices or nodes (the objects) and \(E\) is the set of edges (the pairwise relations). We denote a single node by \(v_i \in V\), and an edge going from node \(v_i\) to node \(v_j\) by \(e_{ij} \in E\). We also define the neighborhood of a node \(v_i\) as the set of nodes reached by an edge leaving \(v_i\). Formally, this set is defined as

\[\mathcal{N}(v_i) = \{v_j \in V~|~e_{ij} \in E\}.\]

Finally, we consider some global features which characterize the whole graph, and we denote them by \(u\).

Undirected and Directed Graphs. The definition above treats each edge of a graph \(G\) as directed from a source node to a recipient node. In some scenarios, however, a graph is undirected: \(e_{ij} = e_{ji}\) for every pair of connected nodes \(v_i\) and \(v_j\). In this case, the relation between nodes is said to be symmetric. If edges are distinguished from their reversed counterparts (\(e_{ij} \neq e_{ji}\)), the graph is said to be directed.
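
This is a minimal sketch of these definitions and is not code from the original post. The edge set \(E\) is stored as ordered pairs `(i, j)`, each standing for \(e_{ij}\). The helper names `neighborhood` and `is_undirected` are hypothetical. Note that `is_undirected` only checks connectivity, meaning whether every edge has a reversed counterpart. It does not check whether both directions carry identical features.

```python
from typing import Hashable, Set, Tuple

Edge = Tuple[Hashable, Hashable]  # (i, j) stands for the edge e_ij from v_i to v_j


def neighborhood(node: Hashable, edges: Set[Edge]) -> Set[Hashable]:
    """N(v_i) = {v_j in V | e_ij in E}: nodes reached by an edge leaving `node`."""
    return {j for (i, j) in edges if i == node}


def is_undirected(edges: Set[Edge]) -> bool:
    """True if every edge e_ij comes with its reversed counterpart e_ji."""
    return all((j, i) in edges for (i, j) in edges)


directed = {(0, 1), (1, 2), (0, 2)}
undirected = directed | {(j, i) for (i, j) in directed}  # add every reversed edge

print(neighborhood(0, directed))    # {1, 2}
print(neighborhood(2, directed))    # set()
print(is_undirected(directed))      # False
print(neighborhood(2, undirected))  # {0, 1}
print(is_undirected(undirected))    # True
```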

Graph Input

The input of a GNN consists of the parsed features of all the nodes and edges of the graph, together with global features characterizing the whole system. Active lines of research, orthogonal to our work, explore methods that extract such parsed features from raw sensory data [Watters et al. 2017; Van Steenkiste et al. 2018; Li et al. 2018; Kipf et al. 2018]. To simplify our study, we assume the existence of a predefined feature extractor that automatically generates input values for each node and edge. With a slight abuse of notation, we denote the input features of node \(i\), edge \(i \rightarrow j\) and the global features by \(v_i\), \(e_{ij}\) and \(u\), respectively.
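
The sketch below shows one possible container for these inputs. It assumes that features are plain lists of floats. `GraphInput` and `fully_connected` are hypothetical names. `fully_connected` stands in for the predefined feature extractor: it arbitrarily connects every ordered pair of objects and uses the relative position \(v_j - v_i\) as the edge feature.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Tuple

Vector = List[float]


@dataclass
class GraphInput:
    nodes: Dict[int, Vector]                 # v_i: input features of node i
    edges: Dict[Tuple[int, int], Vector]     # e_ij: input features of edge i -> j
    u: Vector = field(default_factory=list)  # global features

    def validate(self) -> None:
        """Check that every edge connects two existing nodes."""
        for i, j in self.edges:
            if i not in self.nodes or j not in self.nodes:
                raise ValueError(f"edge {i}->{j} refers to an unknown node")


def fully_connected(nodes: Dict[int, Vector], u: Vector) -> GraphInput:
    """Hypothetical feature extractor: connect every ordered pair of distinct
    nodes and use the relative position v_j - v_i as edge feature e_ij."""
    edges = {
        (i, j): [b - a for a, b in zip(nodes[i], nodes[j])]
        for i in nodes
        for j in nodes
        if i != j
    }
    return GraphInput(nodes=nodes, edges=edges, u=u)


# Three objects described by 2-D positions, plus one global feature.
g = fully_connected({0: [0.0, 0.0], 1: [1.0, 2.0], 2: [3.0, 0.0]}, u=[1.0])
g.validate()
print(len(g.edges))     # 6
print(g.edges[(0, 1)])  # [1.0, 2.0]
```

A fully connected graph is a common default when no relational structure is known in advance. The cost is that the number of edges grows quadratically with the number of nodes.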

Graph Output

Depending on the graph structure and the task at hand, the output of the graph computation can be defined at different levels. When the functions that produce this output are modeled by neural networks, we speak of GNNs.

Node-level. This level focuses on the nodes of the graph. In this scenario, input features including node, edge and global features are used to produce a new embedding for each node. This can be used to perform regression and classification at the level of nodes and learn about the physical dynamics of each object [Battaglia et al. 2016; Chang et al. 2016; Wang et al. 2018; Sanchez-Gonzalez et al. 2018].

Edge-level. This level focuses on the edges of the graph. In this case, the output of the computational scheme consists of the updated features of each edge after propagating information between the nodes. For instance, it can be used to make decisions about interactions among different objects [Kipf et al. 2018; Hamrick et al. 2018].

Graph-level. This level focuses on the entire graph. The output corresponds to a global embedding computed after propagating information between all nodes of the graph. It can be used by embodied agents to produce actions in multi-object scenarios [Akakzia et al. 2021; Akakzia and Sigaud 2022], to answer questions about a visual scene [Santoro et al. 2017], or to predict global properties of molecules in chemistry [Gilmer et al. 2017].
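
To make the three levels concrete, the sketch below assumes that some GNN has already produced embeddings for the nodes and edges of a small graph. It then applies a hypothetical linear head `w` to them. The embeddings and weights are made up. The point is only the shape of each output:

- node-level: one value per node;
- edge-level: one value per edge;
- graph-level: one value for the whole graph, obtained by pooling the node embeddings (sum pooling here).

```python
from typing import Dict, List, Tuple

Vector = List[float]


def dot(w: Vector, x: Vector) -> float:
    return sum(wi * xi for wi, xi in zip(w, x))


# Hypothetical embeddings produced by some GNN for a 3-node graph.
node_emb: Dict[int, Vector] = {0: [1.0, 0.0], 1: [0.5, 0.5], 2: [0.0, 2.0]}
edge_emb: Dict[Tuple[int, int], Vector] = {(0, 1): [1.0, 1.0], (1, 2): [0.0, 3.0]}
w = [1.0, -1.0]  # hypothetical linear head, shared by all outputs for brevity

node_level = {i: dot(w, h) for i, h in node_emb.items()}    # one output per node
edge_level = {ij: dot(w, h) for ij, h in edge_emb.items()}  # one output per edge
pooled = [sum(col) for col in zip(*node_emb.values())]      # sum over all nodes
graph_level = dot(w, pooled)                                # one output per graph

print(node_level)   # {0: 1.0, 1: 0.0, 2: -2.0}
print(edge_level)   # {(0, 1): 0.0, (1, 2): -3.0}
print(graph_level)  # -1.0
```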

Graph Computation

So far, we have formally defined graphs and distinguished three output levels. We now explain how this output is computed. The computational scheme of GNNs relies on two main ingredients. First, shared neural networks compute the updated features of all nodes and edges. Second, aggregation functions pool these features to produce the output. Together, these two ingredients give GNNs good combinatorial generalization capabilities. The shared networks enable transfer between different nodes and edges, and the aggregation scheme makes the output invariant to permutations of the nodes.
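
The permutation invariance provided by the aggregation can be checked directly. A pooling function is permutation invariant if reordering its inputs leaves the result unchanged. This minimal sketch assumes list-of-float features and tries every ordering of three feature vectors. Sum, mean and max pass. Concatenation, included only as a counterexample, does not.

```python
import itertools
from typing import Callable, List

Vector = List[float]


def sum_pool(xs: List[Vector]) -> Vector:
    return [sum(col) for col in zip(*xs)]


def mean_pool(xs: List[Vector]) -> Vector:
    return [sum(col) / len(xs) for col in zip(*xs)]


def max_pool(xs: List[Vector]) -> Vector:
    return [max(col) for col in zip(*xs)]


def concat(xs: List[Vector]) -> Vector:
    # Order-sensitive counterexample: not a valid GNN aggregator.
    return [x for vec in xs for x in vec]


def is_permutation_invariant(pool: Callable[[List[Vector]], Vector],
                             xs: List[Vector]) -> bool:
    # Exact comparison is fine here because the values are small integers;
    # with general floats, sums depend on order, so compare with a tolerance.
    reference = pool(xs)
    return all(pool(list(p)) == reference for p in itertools.permutations(xs))


features = [[1.0, 0.0], [0.0, 2.0], [3.0, 1.0]]
for name, pool in [("sum", sum_pool), ("mean", mean_pool),
                   ("max", max_pool), ("concat", concat)]:
    print(name, is_permutation_invariant(pool, features))
# sum True
# mean True
# max True
# concat False
```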

We denote by \(NN_{nodes}\) the neural network shared across nodes, by \(NN_{edges}\) the neural network shared across edges, and by \(NN_{readout}\) the readout neural network that produces the global output of the GNN. In what follows, we focus on graph-level outputs. The full computational scheme consists of three steps: the edge updates, the node updates and the graph readout.

The edge update step. The edge update step uses the input features involving each edge \(i \rightarrow j\) to compute its updated features, which we denote by \(e'_{ij}\). More precisely, we consider the input features of the source node \(v_i\), the input features of the recipient node \(v_j\), the input features \(e_{ij}\) of the edge itself, and the global input features \(u\). The shared network \(NN_{edges}\) computes the updated features of all the edges. Formally, the updated features \(e'_{ij}\) of the edge \(i \rightarrow j\) are computed as follows:

\[e'_{ij}~=~NN_{edges}(v_i, v_j, e_{ij}, u).\]
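
This is a minimal sketch of the edge update step in plain Python. It assumes graphs are stored as dictionaries. `make_layer` is a hypothetical stand-in for \(NN_{edges}\): a single tanh layer with random, untrained weights. In practice this would be a trained network such as an MLP. Inputs are concatenated in the order \((v_i, v_j, e_{ij}, u)\), as in the formula above.

```python
import math
import random
from typing import Callable, Dict, List, Tuple

Vector = List[float]


def make_layer(in_dim: int, out_dim: int, seed: int) -> Callable[[Vector], Vector]:
    """Stand-in for a trained network: one tanh layer with random weights."""
    rng = random.Random(seed)
    weights = [[rng.gauss(0.0, 0.5) for _ in range(in_dim)] for _ in range(out_dim)]

    def forward(x: Vector) -> Vector:
        assert len(x) == in_dim, "input size must match the layer"
        return [math.tanh(sum(w * xi for w, xi in zip(row, x))) for row in weights]

    return forward


def edge_update(nodes: Dict[int, Vector],
                edges: Dict[Tuple[int, int], Vector],
                u: Vector,
                nn_edges: Callable[[Vector], Vector]) -> Dict[Tuple[int, int], Vector]:
    """e'_ij = NN_edges(v_i, v_j, e_ij, u), with the same network for every edge."""
    return {(i, j): nn_edges(nodes[i] + nodes[j] + e_ij + u)
            for (i, j), e_ij in edges.items()}


nodes = {0: [0.0, 1.0], 1: [1.0, 1.0], 2: [0.0, 1.0]}  # nodes 0 and 2 are identical
edges = {(0, 1): [0.5], (2, 1): [0.5], (1, 0): [-0.5]}
u = [1.0]
nn_edges = make_layer(in_dim=2 + 2 + 1 + 1, out_dim=3, seed=0)
updated = edge_update(nodes, edges, u, nn_edges)
print(updated[(0, 1)] == updated[(2, 1)])  # True: same inputs, shared weights
```

Because the same weights are applied to every edge, two edges with identical inputs (here \(0 \rightarrow 1\) and \(2 \rightarrow 1\)) receive identical updated features. The same network also applies unchanged to graphs with any number of edges.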

The node update step. The node update step computes the updated features of all the nodes; we denote by \(v'_{i}\) the updated features of node \(i\). To do so, it considers the input features of the node itself, the global features, and the aggregation of the updated features of the edges incoming to \(i\). The incoming edges to \(i\) are the edges \(e_{ji} \in E\) whose recipient is \(v_i\). In an undirected graph, their source nodes are exactly the neighbors of \(v_i\), \(\mathcal{N}(v_i)\). The shared network \(NN_{nodes}\) is used in this computation. Formally, the updated features \(v'_{i}\) of node \(i\) are obtained as follows:

\[v'_{i}~=~NN_{nodes}\big(v_i,~Agg_{j\,:\,e_{ji} \in E}(e'_{ji}),~u\big).\]
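
The node update can be sketched in the same spirit, under these assumptions:

- The aggregation is a sum over the incoming edges \(e'_{ji}\).
- The identity function is a placeholder for \(NN_{nodes}\), so that the concatenated input \((v_i, \text{pooled messages}, u)\) is visible in the output.
- The updated edge features are made-up values rather than the output of a real edge update.
- A node without incoming edges receives a zero vector as its aggregated message.

```python
from typing import Callable, Dict, List, Tuple

Vector = List[float]


def incoming_sum(i: int, updated_edges: Dict[Tuple[int, int], Vector], dim: int) -> Vector:
    """Agg_{j: e_ji in E}(e'_ji) with Agg = sum; zero vector if i has no incoming edge."""
    pooled = [0.0] * dim
    for (src, dst), e in updated_edges.items():
        if dst == i:  # only edges whose recipient is node i
            pooled = [p + x for p, x in zip(pooled, e)]
    return pooled


def node_update(nodes: Dict[int, Vector],
                updated_edges: Dict[Tuple[int, int], Vector],
                u: Vector,
                nn_nodes: Callable[[Vector], Vector],
                edge_dim: int) -> Dict[int, Vector]:
    """v'_i = NN_nodes(v_i, Agg_{j: e_ji in E}(e'_ji), u) for every node i."""
    return {i: nn_nodes(v_i + incoming_sum(i, updated_edges, edge_dim) + u)
            for i, v_i in nodes.items()}


nodes = {0: [1.0], 1: [2.0], 2: [3.0]}
updated_edges = {(0, 1): [0.25, 0.5], (2, 1): [0.5, 1.0], (1, 0): [1.0, 1.0]}  # made up
u = [9.0]
out = node_update(nodes, updated_edges, u, nn_nodes=lambda x: x, edge_dim=2)
print(out[1])  # [2.0, 0.75, 1.5, 9.0]
print(out[2])  # [3.0, 0.0, 0.0, 9.0]
```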

The graph readout step. The graph readout step computes the global output of the graph. This quantity is obtained by aggregating all the updated features of the nodes within the graph. It uses the readout neural network \(NN_{readout}\). Formally, the output \(o\) of the GNN is computed as follows:

\[o~=~NN_{readout}\big(Agg_{i\,:\,v_i \in V}(v'_{i})\big).\]
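
Finally, here is a sketch of the readout step. A hypothetical random tanh layer again stands in for \(NN_{readout}\), and the mean is used as the aggregation function (sum or max would work as well). The example checks two properties:

- relabeling or reordering the nodes leaves the output unchanged;
- graphs with different numbers of nodes are mapped to outputs of the same size.

```python
import math
import random
from typing import Callable, Dict, List

Vector = List[float]


def make_layer(in_dim: int, out_dim: int, seed: int) -> Callable[[Vector], Vector]:
    """Stand-in for a trained network: one tanh layer with random weights."""
    rng = random.Random(seed)
    weights = [[rng.gauss(0.0, 0.5) for _ in range(in_dim)] for _ in range(out_dim)]

    def forward(x: Vector) -> Vector:
        return [math.tanh(sum(w * xi for w, xi in zip(row, x))) for row in weights]

    return forward


def readout(updated_nodes: Dict[int, Vector],
            nn_readout: Callable[[Vector], Vector]) -> Vector:
    """o = NN_readout(Agg_{i: v_i in V}(v'_i)), with Agg = mean over nodes."""
    pooled = [sum(col) / len(updated_nodes) for col in zip(*updated_nodes.values())]
    return nn_readout(pooled)


nn_readout = make_layer(in_dim=2, out_dim=1, seed=0)
small = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [2.0, 2.0]}
relabeled = {10: [2.0, 2.0], 11: [1.0, 0.0], 12: [0.0, 1.0]}  # same nodes, new ids/order
large = {i: [float(i), 1.0] for i in range(5)}

print(readout(small, nn_readout) == readout(relabeled, nn_readout))  # True
print(len(readout(large, nn_readout)))                               # 1
```

The mean keeps the scale of the pooled vector independent of the number of nodes, whereas the sum preserves information about graph size. Which one is preferable depends on the task.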

The computational steps described above can also be performed in a different order. For example, one can first update the nodes using the input features of the edges, and then update the edges using the updated node features. This choice usually depends on the domain and task at hand. Besides, the scheme described above belongs to the family of convolutional GNNs [Bruna et al. 2013; Henaff et al. 2015; Defferrard et al. 2016; Kipf and Welling 2016; Levie et al. 2018; Gilmer et al. 2017; Akakzia and Sigaud 2022]. These generalize the convolution operation from grid data to graph data by pooling the features of neighbors when updating each node. Other categories of GNNs include graph auto-encoders [Cao et al. 2016; Wang et al. 2016; Kipf and Welling 2016; Pan et al. 2018; Li et al. 2018], spatio-temporal GNNs [Yu et al. 2017; Li et al. 2017; Seo et al. 2018; Guo et al. 2019] and recurrent GNNs [Scarselli et al. 2005; Gallicchio and Micheli 2010; Li et al. 2015; Dai et al. 2018]. Finally, the aggregation module that performs node-wise pooling can be either a predefined permutation-invariant function such as sum, max or mean, or a more sophisticated self-attention-based function that learns an attention weight for each node [Veličković et al. 2017].
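
As an illustration of learned aggregation weights, the sketch below pools neighbor features with softmax weights computed from a score vector `a`, which would be learned in practice. It is a deliberately simplified, single-head variant. The graph attention networks of [Veličković et al. 2017] compute each score from both the receiving node and the neighbor, through a shared linear map followed by a LeakyReLU, before the softmax. The weighted sum remains permutation invariant, and when all scores are equal it reduces to the mean.

```python
import math
from typing import List, Tuple

Vector = List[float]


def attention_pool(neighbors: List[Vector], a: Vector) -> Tuple[Vector, List[float]]:
    """Weighted sum of neighbor features, with softmax weights over scores a . h_j."""
    scores = [sum(ai * hi for ai, hi in zip(a, h)) for h in neighbors]
    m = max(scores)  # subtract the max score for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    dim = len(neighbors[0])
    pooled = [sum(w * h[k] for w, h in zip(weights, neighbors)) for k in range(dim)]
    return pooled, weights


neighbors = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
uniform, w_uniform = attention_pool(neighbors, a=[0.0, 0.0])    # equal scores: the mean
focused, w_focused = attention_pool(neighbors, a=[5.0, -5.0])   # favors the first neighbor
print(uniform)
print([round(w, 3) for w in w_focused])
```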

Overview on Graph Neural Networks in RL

Graph Neural Networks have recently been widely used in Reinforcement Learning, as they can improve sample efficiency. This is especially true in multi-object manipulation domains, where invariance to the order of objects is crucial for generalization. In this section, we give an overview of recent RL works that use GNNs, divided into two categories: GNNs for model-based RL and GNNs for model-free RL.

Model-based Reinforcement Learning. Using GNNs in model-based reinforcement learning mainly amounts to representing the world perceived by the artificial agent as a graph. Recent papers use GNNs to learn prediction models by constructing graph representations from the bodies and joints of the agents [Wang et al. 2016; Hamrick et al. 2017; Sanchez-Gonzalez et al. 2018]. This approach has been shown to be successful in prediction, system identification and planning. However, these approaches struggle when the agent's components and joints are heterogeneous. For example, they work better on the Swimmer environment than on HalfCheetah, since the latter contains more joints corresponding to different components (back leg, front leg, head …). Other approaches use Interaction Networks [Battaglia et al. 2016], a particular type of GNN, to implement transition models of the environment. They later use these models for imagination-based optimization [Hamrick et al. 2017] or planning from scratch [Wang et al. 2016].
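
Below is a minimal sketch of a GNN transition model over objects. It is loosely in the spirit of Interaction Networks [Battaglia et al. 2016] but is not a reproduction of them. It works in three stages:

- a shared relation network \(f_{rel}\) computes an effect for every ordered pair of objects;
- the effects received by each object are summed;
- a shared object network \(f_{obj}\) predicts a change of state given the action \(a\), which is treated here as a global feature.

All names, dimensions and the residual form \(s'_i = s_i + \Delta_i\) are assumptions. The random tanh layers stand in for networks that would be trained on observed transitions.

```python
import math
import random
from typing import Callable, List

Vector = List[float]

STATE_DIM, ACTION_DIM, EFFECT_DIM = 3, 2, 4


def make_layer(in_dim: int, out_dim: int, seed: int) -> Callable[[Vector], Vector]:
    """Stand-in for a trained network: one tanh layer with random weights."""
    rng = random.Random(seed)
    weights = [[rng.gauss(0.0, 0.5) for _ in range(in_dim)] for _ in range(out_dim)]

    def forward(x: Vector) -> Vector:
        return [math.tanh(sum(w * xi for w, xi in zip(row, x))) for row in weights]

    return forward


f_rel = make_layer(2 * STATE_DIM, EFFECT_DIM, seed=0)                        # shared over pairs
f_obj = make_layer(STATE_DIM + EFFECT_DIM + ACTION_DIM, STATE_DIM, seed=1)  # shared over objects


def predict_next_states(states: List[Vector], action: Vector) -> List[Vector]:
    """s'_i = s_i + f_obj(s_i, sum_{j != i} f_rel(s_i, s_j), a) for every object i."""
    predictions = []
    for i, s_i in enumerate(states):
        effects = [f_rel(s_i + s_j) for j, s_j in enumerate(states) if j != i]
        pooled = [sum(col) for col in zip(*effects)] if effects else [0.0] * EFFECT_DIM
        delta = f_obj(s_i + pooled + action)
        predictions.append([s + d for s, d in zip(s_i, delta)])
    return predictions


states = [[0.0, 0.0, 0.1], [1.0, 0.0, 0.1], [0.0, 1.0, 0.1]]
action = [0.5, -0.5]
next_states = predict_next_states(states, action)
print(len(next_states), len(next_states[0]))  # 3 3
```

Because the same networks are shared across objects and pairs, reordering the objects simply reorders the predictions (permutation equivariance). The model also applies unchanged to scenes with more objects.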

Model-free Reinforcement Learning. GNNs are also used in model-free reinforcement learning to model the policy and/or the value function [Wang et al. 2018; Zambaldi et al. 2018; Bapst et al. 2019; Li et al. 2019; Colas et al. 2020; Akakzia et al. 2021]. As in the model-based setting, some approaches use them to represent the agent's body and joints as a graph whose components interact with each other to produce an action [Wang et al. 2018]. Other approaches use them to represent the world in terms of separate entities and attempt to capture the relations between them [Zambaldi et al. 2018; Bapst et al. 2019; Li et al. 2019; Colas et al. 2020; Akakzia et al. 2021].
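
In the model-free setting, a graph-level output can directly parameterize an action. The sketch below is a relation-network-style simplification of the computational scheme above, assuming a fully connected graph over objects. It works in three stages:

- a shared network `g_pair` embeds every ordered pair of objects together with the goal, which plays the role of the global features \(u\);
- the embeddings are summed;
- a readout `rho` produces an action bounded in \([-1, 1]\) by a tanh.

Names and dimensions are hypothetical, and the weights are random, untrained stand-ins.

```python
import math
import random
from typing import Callable, List

Vector = List[float]

OBJ_DIM, GOAL_DIM, HIDDEN, ACTION_DIM = 3, 3, 8, 4


def make_layer(in_dim: int, out_dim: int, seed: int) -> Callable[[Vector], Vector]:
    """Stand-in for a trained network: one tanh layer with random weights."""
    rng = random.Random(seed)
    weights = [[rng.gauss(0.0, 0.5) for _ in range(in_dim)] for _ in range(out_dim)]

    def forward(x: Vector) -> Vector:
        return [math.tanh(sum(w * xi for w, xi in zip(row, x))) for row in weights]

    return forward


g_pair = make_layer(2 * OBJ_DIM + GOAL_DIM, HIDDEN, seed=0)  # shared edge network
rho = make_layer(HIDDEN, ACTION_DIM, seed=1)                # readout producing the action


def policy(objects: List[Vector], goal: Vector) -> Vector:
    """a = rho(sum over ordered pairs i != j of g_pair(o_i, o_j, goal))."""
    messages = [g_pair(o_i + o_j + goal)
                for i, o_i in enumerate(objects)
                for j, o_j in enumerate(objects) if i != j]
    pooled = [sum(col) for col in zip(*messages)] if messages else [0.0] * HIDDEN
    return rho(pooled)


objects = [[0.0, 0.0, 0.0], [0.1, 0.0, 0.0], [0.0, 0.2, 0.0]]  # e.g. block positions
goal = [0.0, 0.0, 0.1]
action = policy(objects, goal)
print(len(action))  # 4
```

The resulting policy is invariant to the order in which objects are listed and accepts any number of objects. This is the kind of property that motivates the use of GNNs in multi-object settings.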

Limitations

Despite the generalization capabilities provided by permutation invariance, GNNs still struggle with some classes of problems, such as discriminating between certain non-isomorphic graphs [Kondor and Trivedi 2018]. Moreover, notions like recursion, control flow and conditional iteration are not straightforward to represent with graphs and might require domain-specific tweaks (for example, when interpreting abstract syntax trees). Symbolic programs using probabilistic models have been shown to work better on these classes of problems [Tenenbaum et al. 2011; Goodman et al. 2014; Lake et al. 2015]. More importantly, a pressing open question concerns the origin of the graphs that most methods operate on. Most approaches that use GNNs rely on graphs whose entities are predefined and correspond to structured objects. Without this assumption, it is still unclear how to convert sensory data into structured, graph-like representations. Several active lines of research are exploring these issues [Watters et al. 2017; Van Steenkiste et al. 2018; Li et al. 2018; Kipf et al. 2018].
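
A well-known illustration of the first limitation involves the 1-dimensional Weisfeiler-Lehman (1-WL) color-refinement test. Message-passing GNNs of the kind described above can distinguish two graphs at most as well as 1-WL. The sketch below runs 1-WL on a 6-cycle and on two disjoint triangles, assuming identical initial features for every node. Both graphs are 2-regular, so every node receives the same color at every round. The test, and hence any such GNN, therefore cannot tell them apart, even though the graphs are not isomorphic (one is connected, the other is not). A 6-node path is included as a graph that the test does separate from the cycle.

```python
from collections import Counter
from typing import Dict, List

Graph = Dict[int, List[int]]  # adjacency lists of an undirected graph


def wl_histogram(adj: Graph, rounds: int = 3) -> Counter:
    """1-WL color refinement with uniform initial colors; returns the color histogram.

    Colors are kept as nested tuples (no relabeling), so histograms of
    different graphs can be compared directly.
    """
    colors = {v: () for v in adj}  # identical initial features for every node
    for _ in range(rounds):
        # New color = (own color, sorted multiset of neighbor colors).
        colors = {v: (colors[v], tuple(sorted(colors[u] for u in adj[v]))) for v in adj}
    return Counter(colors.values())


cycle6 = {i: [(i - 1) % 6, (i + 1) % 6] for i in range(6)}  # one connected 6-cycle
two_triangles = {0: [1, 2], 1: [0, 2], 2: [0, 1], 3: [4, 5], 4: [3, 5], 5: [3, 4]}
path6 = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3, 5], 5: [4]}

print(wl_histogram(cycle6) == wl_histogram(two_triangles))  # True: not distinguished
print(wl_histogram(path6) == wl_histogram(cycle6))          # False: distinguished
```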

  1. Winston, P.H. 1970. Learning structural descriptions from examples.
  2. Palmer, S.E. 1975. Visual perception and world knowledge: Notes on a model of sensory-cognitive interaction. Explorations in cognition, 279–307.
  3. Navon, D. 1977. Forest before trees: The precedence of global features in visual perception. Cognitive psychology 9, 3, 353–383.
  4. Markman, E.M. 1989. Categorization and naming in children: Problems of induction. MIT Press.
  5. Kemp, C. and Tenenbaum, J.B. 2008. The discovery of structural form. Proceedings of the National Academy of Sciences 105, 31, 10687–10692.
  6. Tenenbaum, J.B., Kemp, C., Griffiths, T.L., and Goodman, N.D. 2011. How to grow a mind: Statistics, structure, and abstraction. Science 331, 6022, 1279–1285.
  7. Battaglia, P., Pascanu, R., Lai, M., Jimenez Rezende, D., and others. 2016. Interaction networks for learning about objects, relations and physics. Advances in neural information processing systems 29.
  8. Battaglia, P.W., Hamrick, J.B., Bapst, V., et al. 2018. Relational inductive biases, deep learning, and graph networks. arXiv preprint arXiv:1806.01261.
  9. Godfrey-Smith, P. 2021. Theory and reality. University of Chicago Press.
  10. Redmon, J., Divvala, S., Girshick, R., and Farhadi, A. 2016. You only look once: Unified, real-time object detection. Proceedings of the IEEE conference on computer vision and pattern recognition, 779–788.
  11. Ren, S., He, K., Girshick, R., and Sun, J. 2015. Faster R-CNN: Towards real-time object detection with region proposal networks. Advances in neural information processing systems 28.
  12. Zhang, W., Xu, L., Li, Z., Lu, Q., and Liu, Y. 2016. A deep-intelligence framework for online video processing. IEEE Software 33, 2, 44–51.
  13. Hinton, G., Deng, L., Yu, D., et al. 2012. Deep neural networks for acoustic modeling in speech recognition: The shared views of four research groups. IEEE Signal processing magazine 29, 6, 82–97.
  14. Luong, M.-T., Pham, H., and Manning, C.D. 2015. Effective approaches to attention-based neural machine translation. arXiv preprint arXiv:1508.04025.
  15. Wu, S., Zhang, D., Yang, N., Li, M., and Zhou, M. 2017. Sequence-to-dependency neural machine translation. Proceedings of the 55th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), 698–707.
  16. Wang, J.X., Kurth-Nelson, Z., Tirumala, D., et al. 2016. Learning to reinforcement learn. arXiv preprint arXiv:1611.05763.
  17. Santoro, A., Raposo, D., Barrett, D.G., et al. 2017. A simple neural network module for relational reasoning. Advances in neural information processing systems 30.
  18. Zaheer, M., Kottur, S., Ravanbakhsh, S., Poczos, B., Salakhutdinov, R.R., and Smola, A.J. 2017. Deep sets. Advances in neural information processing systems, 3391–3401.
  19. Hamrick, J.B., Ballard, A.J., Pascanu, R., Vinyals, O., Heess, N., and Battaglia, P.W. 2017. Metacontrol for adaptive imagination-based optimization. arXiv preprint arXiv:1705.02670.
  20. Sanchez-Gonzalez, A., Heess, N., Springenberg, J.T., et al. 2018. Graph networks as learnable physics engines for inference and control. International Conference on Machine Learning, PMLR, 4470–4479.
  21. Zambaldi, V., Raposo, D., Santoro, A., et al. 2018. Relational deep reinforcement learning. arXiv preprint arXiv:1806.01830.
  22. Wang, T., Liao, R., Ba, J., and Fidler, S. 2018. NerveNet: Learning Structured Policy with Graph Neural Networks. International Conference on Learning Representations.
  23. Bapst, V., Sanchez-Gonzalez, A., Doersch, C., et al. 2019. Structured agents for physical construction. International Conference on Machine Learning, PMLR, 464–474.
  24. Li, R., Jabri, A., Darrell, T., and Agrawal, P. 2019. Towards Practical Multi-Object Manipulation using Relational Reinforcement Learning. arXiv preprint arXiv:1912.11032.
  25. Colas, C., Karch, T., Lair, N., et al. 2020. Language as a Cognitive Tool to Imagine Goals in Curiosity Driven Exploration. Advances in Neural Information Processing Systems 33: Annual Conference on Neural Information Processing Systems 2020, NeurIPS 2020, December 6-12, 2020, virtual.
  26. Akakzia, A., Colas, C., Oudeyer, P.-Y., Chetouani, M., and Sigaud, O. 2021. Grounding Language to Autonomously-Acquired Skills via Goal Generation. 9th International Conference on Learning Representations, ICLR 2021, Virtual Event, Austria, May 3-7, 2021, OpenReview.net.
  27. Akakzia, A. and Sigaud, O. 2022. Learning Object-Centered Autotelic Behaviors with Graph Neural Networks. arXiv preprint arXiv:2204.05141.
  28. Bronstein, M.M., Bruna, J., LeCun, Y., Szlam, A., and Vandergheynst, P. 2017. Geometric deep learning: going beyond euclidean data. IEEE Signal Processing Magazine 34, 4, 18–42.
  29. Hamilton, W.L., Ying, R., and Leskovec, J. 2017. Representation learning on graphs: Methods and applications. arXiv preprint arXiv:1709.05584.
  30. Lee, J.B., Rossi, R., and Kong, X. 2018. Graph classification using structural attention. Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining, 1666–1674.
  31. Wu, Z., Pan, S., Chen, F., Long, G., Zhang, C., and Yu, P.S. 2020. A comprehensive survey on graph neural networks. IEEE transactions on neural networks and learning systems 32, 1, 4–24.
  32. Watters, N., Zoran, D., Weber, T., Battaglia, P., Pascanu, R., and Tacchetti, A. 2017. Visual interaction networks: Learning a physics simulator from video. Advances in neural information processing systems 30.
  33. Van Steenkiste, S., Chang, M., Greff, K., and Schmidhuber, J. 2018. Relational neural expectation maximization: Unsupervised discovery of objects and their interactions. arXiv preprint arXiv:1802.10353.
  34. Li, Y., Vinyals, O., Dyer, C., Pascanu, R., and Battaglia, P. 2018. Learning deep generative models of graphs. arXiv preprint arXiv:1803.03324.
  35. Kipf, T., Fetaya, E., Wang, K.-C., Welling, M., and Zemel, R. 2018. Neural relational inference for interacting systems. International Conference on Machine Learning, PMLR, 2688–2697.
  36. Chang, M.B., Ullman, T., Torralba, A., and Tenenbaum, J.B. 2016. A compositional object-based approach to learning physical dynamics. arXiv preprint arXiv:1612.00341.
  37. Hamrick, J.B., Allen, K.R., Bapst, V., et al. 2018. Relational inductive bias for physical construction in humans and machines. arXiv preprint arXiv:1806.01203.
  38. Gilmer, J., Schoenholz, S.S., Riley, P.F., Vinyals, O., and Dahl, G.E. 2017. Neural message passing for quantum chemistry. arXiv preprint arXiv:1704.01212.
  39. Bruna, J., Zaremba, W., Szlam, A., and LeCun, Y. 2013. Spectral networks and locally connected networks on graphs. arXiv preprint arXiv:1312.6203.
  40. Henaff, M., Bruna, J., and LeCun, Y. 2015. Deep convolutional networks on graph-structured data. arXiv preprint arXiv:1506.05163.
  41. Defferrard, M., Bresson, X., and Vandergheynst, P. 2016. Convolutional neural networks on graphs with fast localized spectral filtering. Advances in neural information processing systems 29.
  42. Kipf, T.N. and Welling, M. 2016. Semi-supervised classification with graph convolutional networks. arXiv preprint arXiv:1609.02907.
  43. Levie, R., Monti, F., Bresson, X., and Bronstein, M.M. 2018. CayleyNets: Graph convolutional neural networks with complex rational spectral filters. IEEE Transactions on Signal Processing 67, 1, 97–109.
  44. Cao, S., Lu, W., and Xu, Q. 2016. Deep neural networks for learning graph representations. Proceedings of the AAAI Conference on Artificial Intelligence.
  45. Wang, D., Cui, P., and Zhu, W. 2016. Structural deep network embedding. Proceedings of the 22nd ACM SIGKDD international conference on Knowledge discovery and data mining, 1225–1234.
  46. Kipf, T.N. and Welling, M. 2016. Variational graph auto-encoders. arXiv preprint arXiv:1611.07308.
  47. Pan, S., Hu, R., Long, G., Jiang, J., Yao, L., and Zhang, C. 2018. Adversarially regularized graph autoencoder for graph embedding. arXiv preprint arXiv:1802.04407.
  48. Yu, B., Yin, H., and Zhu, Z. 2017. Spatio-temporal graph convolutional networks: A deep learning framework for traffic forecasting. arXiv preprint arXiv:1709.04875.
  49. Li, Y., Yu, R., Shahabi, C., and Liu, Y. 2017. Diffusion convolutional recurrent neural network: Data-driven traffic forecasting. arXiv preprint arXiv:1707.01926.
  50. Seo, Y., Defferrard, M., Vandergheynst, P., and Bresson, X. 2018. Structured sequence modeling with graph convolutional recurrent networks. International conference on neural information processing, Springer, 362–373.
  51. Guo, S., Lin, Y., Feng, N., Song, C., and Wan, H. 2019. Attention based spatial-temporal graph convolutional networks for traffic flow forecasting. Proceedings of the AAAI conference on artificial intelligence, 922–929.
  52. Scarselli, F., Yong, S.L., Gori, M., Hagenbuchner, M., Tsoi, A.C., and Maggini, M. 2005. Graph neural networks for ranking web pages. The 2005 IEEE/WIC/ACM International Conference on Web Intelligence (WI’05), IEEE, 666–672.
  53. Gallicchio, C. and Micheli, A. 2010. Graph echo state networks. The 2010 international joint conference on neural networks (IJCNN), IEEE, 1–8.
  54. Li, Y., Tarlow, D., Brockschmidt, M., and Zemel, R. 2015. Gated graph sequence neural networks. arXiv preprint arXiv:1511.05493.
  55. Dai, H., Kozareva, Z., Dai, B., Smola, A., and Song, L. 2018. Learning steady-states of iterative algorithms over graphs. International conference on machine learning, PMLR, 1106–1114.
  56. Veličković, P., Cucurull, G., Casanova, A., Romero, A., Lio, P., and Bengio, Y. 2017. Graph attention networks. arXiv preprint arXiv:1710.10903.
  57. Kondor, R. and Trivedi, S. 2018. On the generalization of equivariance and convolution in neural networks to the action of compact groups. International Conference on Machine Learning, PMLR, 2747–2755.
  58. Goodman, N.D., Tenenbaum, J.B., and Gerstenberg, T. 2014. Concepts in a probabilistic language of thought. Center for Brains, Minds and Machines (CBMM).
  59. Lake, B.M., Salakhutdinov, R., and Tenenbaum, J.B. 2015. Human-level concept learning through probabilistic program induction. Science 350, 6266, 1332–1338.

About Ahmed Akakzia

Ahmed is a final-year PhD candidate passionate about artificial intelligence.

Paris, France