treewizard.py
Go to the documentation of this file.
00001 """ @package antlr3.tree
00002 @brief ANTLR3 runtime package, treewizard module
00003 
00004 A utility module to create ASTs at runtime.
00005 See <http://www.antlr.org/wiki/display/~admin/2007/07/02/Exploring+Concept+of+TreeWizard> for an overview. Note that the API of the Python implementation is slightly different.
00006 
00007 """
00008 
00009 # begin[licence]
00010 #
00011 # [The "BSD licence"]
00012 # Copyright (c) 2005-2008 Terence Parr
00013 # All rights reserved.
00014 #
00015 # Redistribution and use in source and binary forms, with or without
00016 # modification, are permitted provided that the following conditions
00017 # are met:
00018 # 1. Redistributions of source code must retain the above copyright
00019 #    notice, this list of conditions and the following disclaimer.
00020 # 2. Redistributions in binary form must reproduce the above copyright
00021 #    notice, this list of conditions and the following disclaimer in the
00022 #    documentation and/or other materials provided with the distribution.
00023 # 3. The name of the author may not be used to endorse or promote products
00024 #    derived from this software without specific prior written permission.
00025 #
00026 # THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
00027 # IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
00028 # OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
00029 # IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
00030 # INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
00031 # NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
00032 # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
00033 # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
00034 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
00035 # THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00036 #
00037 # end[licence]
00038 
00039 from constants import INVALID_TOKEN_TYPE
00040 from tokens import CommonToken
00041 from tree import CommonTree, CommonTreeAdaptor
00042 
00043 
def computeTokenTypes(tokenNames):
    """
    Invert tokenNames into a name -> token type dict.

    tokenNames is a sequence whose i-th entry is the name of token type i.
    The result maps every name back to its index; if a name occurs more
    than once, the last index wins.  None yields an empty dict.
    """

    nameToType = {}
    if tokenNames is not None:
        for tokenType, tokenName in enumerate(tokenNames):
            nameToType[tokenName] = tokenType

    return nameToType
00054 
00055 
## token types for pattern parser
EOF = -1
BEGIN = 1
END = 2
ID = 3
ARG = 4
PERCENT = 5
COLON = 6
DOT = 7

class TreePatternLexer(object):
    """
    Tokenizer for tree patterns such as "(A %lbl:B[text] .)".

    Call nextToken() repeatedly.  It returns one of the token type
    constants above: BEGIN '(', END ')', PERCENT '%', COLON ':', DOT '.',
    ID (an identifier), ARG (the text between '[' and ']'), or EOF.
    After ID or ARG, the matched text is in self.sval.

    On a lexical error (an unexpected character, or input ending inside a
    "[...]" argument) nextToken() sets self.error and returns EOF.
    """

    def __init__(self, pattern):
        ## The tree pattern to lex like "(A B C)"
        self.pattern = pattern

        ## Index into input string
        self.p = -1

        ## Current char; the integer EOF once the input is exhausted
        self.c = None

        ## How long is the pattern in char?
        self.n = len(pattern)

        ## Set when token type is ID or ARG
        self.sval = None

        ## Set when a lexical error was encountered
        self.error = False

        self.consume()


    __idStartChar = frozenset(
        'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ_'
        )
    __idChar = __idStartChar | frozenset('0123456789')

    def nextToken(self):
        """Return the type of the next token (see the class docstring)."""

        self.sval = ""
        while self.c != EOF:
            if self.c in (' ', '\n', '\r', '\t'):
                self.consume()
                continue

            if self.c in self.__idStartChar:
                self.sval += self.c
                self.consume()
                while self.c in self.__idChar:
                    self.sval += self.c
                    self.consume()

                return ID

            if self.c == '(':
                self.consume()
                return BEGIN

            if self.c == ')':
                self.consume()
                return END

            if self.c == '%':
                self.consume()
                return PERCENT

            if self.c == ':':
                self.consume()
                return COLON

            if self.c == '.':
                self.consume()
                return DOT

            if self.c == '[': # grab [x] as a string, returning x
                return self._lexArg()

            # unexpected character
            self.consume()
            self.error = True
            return EOF

        return EOF


    def _lexArg(self):
        """
        Lex a "[...]" argument; self.c must be the opening bracket.

        Inside the brackets a backslash directly before ']' yields a literal
        ']'.  Any other backslash is kept verbatim.  Returns ARG with the
        text in self.sval.  If the input ends before the closing bracket,
        it sets self.error and returns EOF.
        """

        self.consume() # skip '['
        while self.c != ']':
            if self.c == '\\':
                self.consume()
                if self.c != ']':
                    self.sval += '\\'

            if self.c == EOF:
                # Unterminated argument.  Without this check the integer
                # EOF would be appended to the str and raise TypeError.
                self.error = True
                return EOF

            self.sval += self.c
            self.consume()

        self.consume() # skip ']'
        return ARG


    def consume(self):
        """Advance to the next character, setting self.c to EOF at the end."""

        self.p += 1
        if self.p >= self.n:
            self.c = EOF

        else:
            self.c = self.pattern[self.p]
00161 
00162 
class TreePatternParser(object):
    """
    Recursive-descent parser that turns the tokens produced by a
    TreePatternLexer into a tree built with the given adaptor.

    Informal grammar:
        pattern : tree | node
        tree    : '(' node (tree | node)* ')'
        node    : ('%' ID ':')? ('.' | ID ('[' arg ']')?)

    Every parse method returns None when the pattern is malformed.
    """

    def __init__(self, tokenizer, wizard, adaptor):
        self.tokenizer = tokenizer
        self.wizard = wizard   # resolves token names via getTokenType()
        self.adaptor = adaptor # builds the resulting nodes
        self.ttype = tokenizer.nextToken() # kickstart


    def pattern(self):
        """
        Parse a complete pattern and return its root node.

        Returns None if the pattern is malformed, if tokens are left over
        after it, or if the tokenizer flagged a lexical error.
        """

        if self.ttype == BEGIN:
            tree = self.parseTree()

        elif self.ttype == ID:
            tree = self.parseNode()

        else:
            return None

        # The lexer reports an invalid character as EOF with its error flag
        # set, so EOF alone does not prove the whole input was consumed.
        if self.ttype != EOF or getattr(self.tokenizer, 'error', False):
            return None # extra junk on end

        return tree


    def parseTree(self):
        """Parse '(' root child* ')' and return the root, or None."""

        if self.ttype != BEGIN:
            return None

        self.ttype = self.tokenizer.nextToken()
        root = self.parseNode()
        if root is None:
            return None

        while self.ttype in (BEGIN, ID, PERCENT, DOT):
            if self.ttype == BEGIN:
                child = self.parseTree()

            else:
                child = self.parseNode()

            # A malformed child (including a nested subtree) invalidates the
            # whole tree.  Do not hand None to addChild and keep parsing.
            if child is None:
                return None

            self.adaptor.addChild(root, child)

        if self.ttype != END:
            return None

        self.ttype = self.tokenizer.nextToken()
        return root


    def parseNode(self):
        """
        Parse a single node.

        A node is an optional "%label:" prefix followed by either the
        wildcard "." or "ID" / "ID[arg]".  The name "nil" yields
        adaptor.nil().  With "[arg]" the node text is arg; otherwise it is
        the token name.  Labels and the text-arg flag are only recorded on
        TreePattern nodes.

        Returns None if the node is malformed or names a token the wizard
        does not know.
        """

        # "%label:" prefix
        label = None

        if self.ttype == PERCENT:
            self.ttype = self.tokenizer.nextToken()
            if self.ttype != ID:
                return None

            label = self.tokenizer.sval
            self.ttype = self.tokenizer.nextToken()
            if self.ttype != COLON:
                return None

            self.ttype = self.tokenizer.nextToken() # move to ID following colon

        # Wildcard?
        if self.ttype == DOT:
            self.ttype = self.tokenizer.nextToken()
            wildcardPayload = CommonToken(0, ".")
            node = WildcardTreePattern(wildcardPayload)
            if label is not None:
                node.label = label
            return node

        # "ID" or "ID[arg]"
        if self.ttype != ID:
            return None

        tokenName = self.tokenizer.sval
        self.ttype = self.tokenizer.nextToken()

        if tokenName == "nil":
            return self.adaptor.nil()

        text = tokenName
        # check for arg
        arg = None
        if self.ttype == ARG:
            arg = self.tokenizer.sval
            text = arg
            self.ttype = self.tokenizer.nextToken()

        # create node
        treeNodeType = self.wizard.getTokenType(tokenName)
        if treeNodeType == INVALID_TOKEN_TYPE:
            return None

        node = self.adaptor.createFromType(treeNodeType, text)
        if isinstance(node, TreePattern):
            # only pattern nodes carry the matching metadata
            if label is not None:
                node.label = label

            if arg is not None:
                node.hasTextArg = True

        return node
00269 
00270 
class TreePattern(CommonTree):
    """
    CommonTree node used in compiled patterns.

    Besides its token it remembers the label from a "%label:TOKENNAME"
    prefix (None if there was none).  It also records whether the pattern
    spelled out a "TOKENNAME[text]" argument, in which case the matched
    node's text must equal this node's text.
    """

    def __init__(self, payload):
        CommonTree.__init__(self, payload)

        # label given in the pattern, or None
        self.label = None
        # True when the pattern node carried a "[text]" argument
        self.hasTextArg = None


    def toString(self):
        rendered = CommonTree.toString(self)
        if self.label is None:
            return rendered

        return '%' + self.label + ':' + rendered
00290 
00291 
class WildcardTreePattern(TreePattern):
    """
    Pattern node created for the '.' wildcard.

    When matching, its token type and text are not compared against the
    tree node, but labels and child counts are still handled as for any
    other pattern node.
    """
00294 
00295 
class TreePatternTreeAdaptor(CommonTreeAdaptor):
    """
    Adaptor used when compiling patterns for matching.

    createWithPayload() wraps each payload in a TreePattern, so the parsed
    pattern nodes can carry labels and text-argument flags.
    """

    def createWithPayload(self, payload):
        patternNode = TreePattern(payload)
        return patternNode
00301 
00302 
class TreeWizard(object):
    """
    Build and navigate trees with this object.  Must know about the names
    of tokens so you have to pass in a map or array of token names (from which
    this class can build the map).  I.e., Token DECL means nothing unless the
    class can translate it to a token type.

    In order to create nodes and navigate, this class needs a TreeAdaptor.

    This class can build a token type -> node index for repeated use or for
    iterating over the various nodes with a particular type.

    This class works in conjunction with the TreeAdaptor rather than moving
    all this functionality into the adaptor.  An adaptor helps build and
    navigate trees using methods.  This class helps you do it with string
    patterns like "(A B C)".  You can create a tree from that pattern or
    match subtrees against it.
    """

    def __init__(self, adaptor=None, tokenNames=None, typeMap=None):
        """
        adaptor: TreeAdaptor used to create and inspect nodes.  A new
            CommonTreeAdaptor is used when None.
        tokenNames: sequence of token names indexed by token type.
        typeMap: dict mapping token names to token types.  This is an
            alternative to tokenNames.

        Raises ValueError if both tokenNames and typeMap are given.
        """

        if adaptor is None:
            # Every method below goes through self.adaptor, so leaving it
            # None would make them fail with AttributeError.
            adaptor = CommonTreeAdaptor()
        self.adaptor = adaptor

        if typeMap is None:
            self.tokenNameToTypeMap = computeTokenTypes(tokenNames)

        else:
            if tokenNames is not None:
                raise ValueError("Can't have both tokenNames and typeMap")

            self.tokenNameToTypeMap = typeMap


    def getTokenType(self, tokenName):
        """
        Using the map of token names to token types, return the type.

        Returns INVALID_TOKEN_TYPE for unknown names.
        """

        try:
            return self.tokenNameToTypeMap[tokenName]
        except KeyError:
            return INVALID_TOKEN_TYPE


    def create(self, pattern):
        """
        Create a tree or node from the indicated tree pattern that closely
        follows ANTLR tree grammar tree element syntax:

        (root child1 ... childN).

        You can also just pass in a node: ID

        Any node can have a text argument: ID[foo]
        (notice there are no quotes around foo--it's clear it's a string).

        nil is a special name meaning "give me a nil node".  Useful for
        making lists: (nil A B C) is a list of A B C.

        Nodes are built with this wizard's adaptor.  Returns None if the
        pattern is malformed.
        """

        tokenizer = TreePatternLexer(pattern)
        parser = TreePatternParser(tokenizer, self, self.adaptor)
        return parser.pattern()


    def _compilePattern(self, pattern):
        """
        Parse pattern into TreePattern nodes for matching.

        Returns None if the pattern is malformed.
        """

        tokenizer = TreePatternLexer(pattern)
        parser = TreePatternParser(tokenizer, self, TreePatternTreeAdaptor())
        return parser.pattern()


    def _compileRootedPattern(self, pattern):
        """
        Like _compilePattern, but also return None for patterns that cannot
        drive a search by root token type, i.e. nil or wildcard roots.
        """

        tpattern = self._compilePattern(pattern)
        if (tpattern is None or tpattern.isNil()
            or isinstance(tpattern, WildcardTreePattern)):
            return None

        return tpattern


    def index(self, tree):
        """Walk the entire tree and make a node name to nodes mapping.

        For now, use recursion but later nonrecursive version may be
        more efficient.  Returns a dict int -> list where the list is
        of your AST node type.  The int is the token type of the node.
        """

        m = {}
        self._index(tree, m)
        return m


    def _index(self, t, m):
        """Do the work for index: add t and its descendants to m."""

        if t is None:
            return

        m.setdefault(self.adaptor.getType(t), []).append(t)
        for i in range(self.adaptor.getChildCount(t)):
            self._index(self.adaptor.getChild(t, i), m)


    def find(self, tree, what):
        """Return a list of matching nodes.

        what may either be an integer specifying the token type to find or
        a string with a pattern that must be matched.  For a pattern that
        is malformed, nil-rooted or wildcard-rooted, None is returned
        instead of a list.

        Raises TypeError if what is neither an integer nor a string.
        """

        if isinstance(what, (int, long)):
            return self._findTokenType(tree, what)

        elif isinstance(what, basestring):
            return self._findPattern(tree, what)

        else:
            raise TypeError("'what' must be string or integer")


    def _findTokenType(self, t, ttype):
        """Return a List of tree nodes with token type ttype"""

        nodes = []

        def visitor(tree, parent, childIndex, labels):
            nodes.append(tree)

        self.visit(t, ttype, visitor)

        return nodes


    def _findPattern(self, t, pattern):
        """
        Return a List of subtrees matching pattern, or None if the pattern
        is malformed, nil-rooted or wildcard-rooted.
        """

        # don't allow invalid patterns
        tpattern = self._compileRootedPattern(pattern)
        if tpattern is None:
            return None

        subtrees = []

        def visitor(tree, parent, childIndex, label):
            if self._parse(tree, tpattern, None):
                subtrees.append(tree)

        self.visit(t, tpattern.getType(), visitor)

        return subtrees


    def visit(self, tree, what, visitor):
        """Visit every node in tree matching what, invoking the visitor.

        The visitor is called as visitor(node, parent, childIndex, labels).

        If what is a string, it is parsed as a pattern and only matching
        subtrees will be visited.  labels is then a dict of the pattern's
        labels.
        The implementation uses the root node of the pattern in combination
        with visit(t, ttype, visitor) so nil-rooted patterns are not allowed.
        Patterns with wildcard roots are also not allowed.  Such patterns, and
        malformed ones, cause no visits at all.

        If what is an integer, it is used as a token type and visit will match
        all nodes of that type (this is faster than the pattern match).
        The labels arg of the visitor action method is never set (it's None)
        since using a token type rather than a pattern doesn't let us set a
        label.

        Raises TypeError if what is neither an integer nor a string.
        """

        if isinstance(what, (int, long)):
            self._visitType(tree, None, 0, what, visitor)

        elif isinstance(what, basestring):
            self._visitPattern(tree, what, visitor)

        else:
            raise TypeError("'what' must be string or integer")


    def _visitType(self, t, parent, childIndex, ttype, visitor):
        """Do the recursive work for visit"""

        if t is None:
            return

        if self.adaptor.getType(t) == ttype:
            visitor(t, parent, childIndex, None)

        for i in range(self.adaptor.getChildCount(t)):
            child = self.adaptor.getChild(t, i)
            self._visitType(child, t, i, ttype, visitor)


    def _visitPattern(self, tree, pattern, visitor):
        """
        For all subtrees that match the pattern, execute the visit action.
        """

        # don't allow invalid patterns
        tpattern = self._compileRootedPattern(pattern)
        if tpattern is None:
            return

        def rootvisitor(tree, parent, childIndex, labels):
            # each candidate gets its own, fresh labels dict
            labels = {}
            if self._parse(tree, tpattern, labels):
                visitor(tree, parent, childIndex, labels)

        self.visit(tree, tpattern.getType(), rootvisitor)


    def parse(self, t, pattern, labels=None):
        """
        Given a pattern like (ASSIGN %lhs:ID %rhs:.) with optional labels
        on the various nodes and '.' (dot) as a wildcard, return True if the
        pattern matches and fill the labels dict with the labels pointing at
        the appropriate nodes.  Return False if the pattern is malformed or
        the tree does not match.

        The wildcard matches any token type and text, but the number of
        children is still compared, so '.' without children only matches
        a node without children.

        If a node specifies a text arg in pattern, then that must match
        for that node in t.

        labels may be partially filled even when the match fails.
        """

        return self._parse(t, self._compilePattern(pattern), labels)


    def _parse(self, t1, tpattern, labels):
        """
        Do the work for parse. Check to see if the tpattern fits the
        structure and token types in t1.  Check text if the pattern has
        text arguments on nodes.  Fill labels map with pointers to nodes
        in tree matched against nodes in pattern with labels.
        """

        # make sure both are non-null
        if t1 is None or tpattern is None:
            return False

        # check roots (wildcard matches any type/text)
        if not isinstance(tpattern, WildcardTreePattern):
            if self.adaptor.getType(t1) != tpattern.getType():
                return False

            # if pattern has text, check node text
            if (tpattern.hasTextArg
                and self.adaptor.getText(t1) != tpattern.getText()):
                return False

        if tpattern.label is not None and labels is not None:
            # map label in pattern to node in t1
            labels[tpattern.label] = t1

        # check children
        n1 = self.adaptor.getChildCount(t1)
        n2 = tpattern.getChildCount()
        if n1 != n2:
            return False

        for i in range(n1):
            child1 = self.adaptor.getChild(t1, i)
            child2 = tpattern.getChild(i)
            if not self._parse(child1, child2, labels):
                return False

        return True


    def equals(self, t1, t2, adaptor=None):
        """
        Compare t1 and t2; return true if token types/text, structure match
        exactly.
        The trees are examined in their entirety so that (A B) does not match
        (A B C) nor (A (B C)).
        If adaptor is None, this wizard's adaptor is used.
        """

        if adaptor is None:
            adaptor = self.adaptor

        return self._equals(t1, t2, adaptor)


    def _equals(self, t1, t2, adaptor):
        """Recursive worker for equals(); False if either tree is None."""

        # make sure both are non-null
        if t1 is None or t2 is None:
            return False

        # check roots
        if adaptor.getType(t1) != adaptor.getType(t2):
            return False

        if adaptor.getText(t1) != adaptor.getText(t2):
            return False

        # check children
        n1 = adaptor.getChildCount(t1)
        n2 = adaptor.getChildCount(t2)
        if n1 != n2:
            return False

        for i in range(n1):
            child1 = adaptor.getChild(t1, i)
            child2 = adaptor.getChild(t2, i)
            if not self._equals(child1, child2, adaptor):
                return False

        return True


rve_interface_gen
Author(s): Josh Faust
autogenerated on Wed Dec 11 2013 14:31:00