kripkomat/astvisitors.py

614 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import re
from pycparser import c_ast, c_generator
from itertools import pairwise
from dataclasses import dataclass
def path_to_str(p):
    """Render a node path as ``"First->Second->..."`` using each node's class name.

    Useful for debugging the paths produced by the path visitors.
    """
    return "->".join(type(node).__name__ for node in p)
class FuncDefVisitor(c_ast.NodeVisitor):
    """Collect function definitions by name.

    After a visit, ``func_table`` maps each function name to its
    ``c_ast.FuncDef`` node. A later definition with the same name replaces
    an earlier one. The bodies of definitions are not descended into.
    """

    def __init__(self):
        self.func_table = {}

    def visit_FuncDef(self, node):
        name = node.decl.name
        self.func_table[name] = node
class FuncCallVisitor(c_ast.NodeVisitor):
    """Collect the names of functions that are called directly by name.

    After a visit, ``func_calls`` lists callee names in visit order, with
    repeats. Calls made through an expression have no plain name and are not
    recorded. Examples are function pointers (``(*fp)(x)``) and struct members
    (``s.cb(x)``).

    The sub-expressions of every call (callee expression and arguments) are
    still searched, so nested calls such as ``g`` in ``f(g(x))`` are found.
    """

    def __init__(self):
        self.func_calls = []

    def visit_FuncCall(self, node):
        if isinstance(node.name, c_ast.ID):
            self.func_calls.append(node.name.name)
        # Descend so calls nested in arguments / the callee expression are seen.
        self.generic_visit(node)
class EnumDefVisitor(c_ast.NodeVisitor):
    """Store the enum found under a node under a caller-chosen name.

    Every ``Enum`` node visited is stored in ``enums[name]``. A later one
    overwrites an earlier one. Enum children are not descended into.
    """

    def __init__(self, name):
        super().__init__()
        self._name = name
        self.enums = {}

    def visit_Enum(self, node):
        key = self._name
        self.enums[key] = node
class EnumTypedefVisitor(c_ast.NodeVisitor):
    """Map typedef names to the enum declared inside each typedef.

    For every ``Typedef`` visited, an ``EnumDefVisitor`` keyed by the
    typedef's name searches it. Any enum it finds is merged into ``enums``
    (``typedef name -> Enum node``). The typedef itself is not descended into
    further by this visitor.
    """

    def __init__(self):
        self.enums = {}

    def visit_Typedef(self, node):
        finder = EnumDefVisitor(node.name)
        finder.visit(node)
        # Rebinds to a fresh merged dict; entries found later win.
        self.enums = self.enums | finder.enums
class EnumVisitor(c_ast.NodeVisitor):
    """Collect the names of all enumerators under a node, in visit order."""

    def __init__(self):
        super().__init__()
        self.enum_names = []

    def visit_Enumerator(self, node):
        names = self.enum_names
        names.append(node.name)
class PathVisitor(object):
    """Depth-first walker whose handlers also receive the ancestor path.

    ``visit(node, path)`` dispatches to ``visit_<ClassName>(node, path)`` if
    a subclass defines it, otherwise to ``generic_visit``. ``path`` is the
    list of ancestors from the walk's root down to (but excluding) ``node``.
    ``generic_visit`` builds a new list for the children, so a path handed to
    a handler is never mutated later by the walker.
    """

    def __init__(self):
        # class name -> bound handler, filled lazily
        self._method_cache = {}

    def visit(self, node, path = None):
        ancestors = [] if path is None else path
        kind = node.__class__.__name__
        handler = self._method_cache.get(kind)
        if handler is None:
            handler = getattr(self, 'visit_' + kind, self.generic_visit)
            self._method_cache[kind] = handler
        return handler(node, ancestors)

    def generic_visit(self, node, path):
        below = path + [node]
        for child in node:
            self.visit(child, below)
class FuncCallDeferPathVisitor(PathVisitor):
    """PathVisitor that continues into the bodies of called functions.

    When a ``FuncCall`` names a function present in ``func_table``
    (name -> ``FuncDef``), that definition is walked with the path of the
    call site, so the callee's nodes appear "inside" the caller. The
    ``FuncCall`` node itself is not put on the path.

    Other behaviour:
    - Calls to unknown functions, and calls through expressions (function
      pointers, struct members), are not followed.
    - The call's argument expressions are not walked.
    - A call whose definition is already on the current path is
      (mutually) recursive. It is not followed again, otherwise the walk
      would never terminate.
    """

    def __init__(self, func_table):
        super().__init__()
        self.func_table = func_table

    def visit_FuncCall(self, node, path):
        callee = node.name
        if not isinstance(callee, c_ast.ID) or callee.name not in self.func_table:
            return
        funcdef = self.func_table[callee.name]
        if any(ancestor is funcdef for ancestor in path):
            return
        self.visit(funcdef, path)
def check_node_equivalence(lhs, rhs):
    """Return True if ``lhs`` and ``rhs`` have the same string rendering.

    For pycparser nodes this presumably compares their printed structure
    rather than object identity. Confirm against the pycparser version used.
    """
    lhs_text, rhs_text = str(lhs), str(rhs)
    return lhs_text == rhs_text
def check_structref_equivalence(lhs:c_ast.StructRef, rhs:c_ast.StructRef):
    """Return True if two member accesses name the same member chain.

    The chains are compared structurally, so ``a.b.c`` matches ``a.b.c``:
    every field name along the chain must agree, and the base identifiers
    must have the same name. The access operator is not compared, so
    ``p->x`` matches ``p.x``.

    False is returned in these cases:
    - either argument is not a ``StructRef`` (e.g. a plain identifier
      lvalue);
    - a chain's base is not an identifier (array index, cast, dereference,
      call).
    """
    def same_ref(a, b):
        if isinstance(a, c_ast.StructRef) and isinstance(b, c_ast.StructRef):
            return a.field.name == b.field.name and same_ref(a.name, b.name)
        if isinstance(a, c_ast.ID) and isinstance(b, c_ast.ID):
            return a.name == b.name
        return False

    if not (isinstance(lhs, c_ast.StructRef) and isinstance(rhs, c_ast.StructRef)):
        return False
    return same_ref(lhs, rhs)
@dataclass
class PathStateAssignment:
    """One assignment of a new state to the tracked state variable.

    Fields:
        path: the ancestor nodes of the assignment, as handed to the
            visitor that found it.
        node: the ``Assignment`` node itself.
        state: name of the identifier that was assigned.
    """

    path: list[c_ast.Node]
    node: c_ast.Assignment
    state: str
class StateAssignmentVisitor(FuncCallDeferPathVisitor):
    """Record plain ``variable = IDENT;`` assignments along with their paths.

    ``variable`` is the lvalue that holds the state. It is presumably a
    ``c_ast.StructRef`` such as ``fsm->state``; a bare ``c_ast.ID`` is also
    accepted. For every simple assignment (``=`` only) of an identifier to
    it, a ``PathStateAssignment(path, node, name)`` is appended to
    ``assignments``. Calls to functions in ``func_table`` are followed (see
    ``FuncCallDeferPathVisitor``).

    Assignments to any other lvalue are ignored, and so are compound
    assignments such as ``state |= X``. The old code could crash on a non-
    StructRef lvalue and recorded compound assignments as state changes.
    """

    def __init__(self, func_table, variable):
        super().__init__(func_table)
        self.variable = variable
        self.assignments = []

    def _targets_variable(self, lvalue):
        """Return True if ``lvalue`` denotes the tracked state variable."""
        if isinstance(self.variable, c_ast.StructRef):
            return (isinstance(lvalue, c_ast.StructRef)
                    and check_structref_equivalence(self.variable, lvalue))
        if isinstance(self.variable, c_ast.ID):
            return isinstance(lvalue, c_ast.ID) and lvalue.name == self.variable.name
        return False

    def visit_Assignment(self, n, path):
        if n.op != '=' or not self._targets_variable(n.lvalue):
            return
        if isinstance(n.rvalue, c_ast.ID):
            self.assignments.append(PathStateAssignment(path, n, n.rvalue.name))
        # TODO: `state = cond ? A : B` (c_ast.TernaryOp) is not handled yet;
        # it would need the ternary condition attached to each outcome.
@dataclass
class PathAssignment:
    """An assignment node together with the ancestor path it was found at.

    Fields:
        path: the ancestor nodes, as handed to the visitor that found it.
        node: the ``Assignment`` node itself.
    """

    path: list[c_ast.Node]
    node: c_ast.Assignment
class AssignmentVisitor(FuncCallDeferPathVisitor):
    """Record every assignment reached, with its ancestor path.

    Calls to functions in ``func_table`` are followed (see
    ``FuncCallDeferPathVisitor``). Assignments are not descended into.
    """

    def __init__(self, func_table):
        super().__init__(func_table)
        self.assignments = []

    def visit_Assignment(self, n, path):
        record = PathAssignment(path, n)
        self.assignments.append(record)
class PathAnalyzer(object):
    """Walk a flat sequence of nodes, dispatching each to an optional hook.

    ``analyze(path)`` calls ``visit_<ClassName>(node, trace)`` for every
    node that has such a method. ``trace`` is the list of nodes that came
    before ``node`` in the sequence. It is one list object that keeps
    growing, so a hook that stores it will later see the appended nodes.
    Nodes without a hook are skipped.
    """

    def __init__(self):
        # class name -> bound hook, or None when the class has no hook
        self._method_cache = {}

    def analyze(self, path):
        trace = []
        for node in path:
            self._invoke_visitor(node, trace)
            trace.append(node)

    def _invoke_visitor(self, node, trace):
        kind = node.__class__.__name__
        try:
            hook = self._method_cache[kind]
        except KeyError:
            hook = getattr(self, 'visit_' + kind, None)
            self._method_cache[kind] = hook
        if hook is None:
            return
        hook(node, trace)
@dataclass
class GeneralizedCase:
    """A group of switch labels sharing one body.

    Fields:
        exprs: the case label expressions of the group.
        stmts: the statements executed for the group.
    """

    exprs: list[c_ast.Node]
    stmts: list[c_ast.Node]
class ConditionalPathAnalyzer(PathAnalyzer):
    """Collect the branch conditions that guard a path through C code.

    ``analyze(path)`` appends one condition to ``condition_chain`` for each
    ``If``/``Switch`` on the path that selects between branches. The
    condition is the one under which control continues towards the next
    node. The handlers read ``trace[-1]`` as the *child* of the node being
    visited, so the path is presumably supplied leaf-first (the reverse of a
    root-to-leaf path). TODO: confirm against callers.
    """

    # Statements after which control cannot fall into the next case label.
    _CASE_TERMINATORS = (c_ast.Break, c_ast.Return, c_ast.Continue, c_ast.Goto)

    def __init__(self):
        super().__init__()
        self.condition_chain = []

    def visit_If(self, node, trace):
        """Append ``cond`` (then-branch taken) or ``!cond`` (else-branch taken)."""
        if not trace:
            return
        child = trace[-1]
        if child is node.iftrue:
            self.condition_chain.append(node.cond)
        elif child is node.iffalse:
            self.condition_chain.append(c_ast.UnaryOp('!', node.cond))
        # Otherwise the path runs through the condition expression: no guard.

    def visit_Switch(self, node, trace):
        """Append the condition under which the traced case body executes.

        Fallthrough semantics: a statement under label k runs when the
        switch value matches k or any label directly before k that falls
        through into it (a label whose statements do not end in a jump).
        A ``default`` label matches when the value equals none of the
        switch's ``case`` expressions. Terms for several labels are combined
        with ``||``.

        Nothing is appended in these cases:
        - the path does not enter the switch body through a label (e.g. it
          goes through the switch expression);
        - the body is not a compound statement;
        - a reachable ``default`` has no ``case`` siblings, so it is
          unconditional.
        """
        if len(trace) < 2 or trace[-1] is not node.stmt:
            return
        if not isinstance(node.stmt, c_ast.Compound):
            return
        labels = [item for item in (node.stmt.block_items or [])
                  if isinstance(item, (c_ast.Case, c_ast.Default))]
        run = self._fallthrough_run(labels, trace[-2])
        if run is None:
            return
        terms = self._label_terms(node.cond, labels, run)
        if terms is None:
            return
        self.condition_chain.append(self._fold('||', terms))

    @classmethod
    def _ends_with_jump(cls, stmts):
        """Return True if the last statement of ``stmts`` unconditionally leaves the case.

        A trailing compound block ``{ ...; break; }`` is looked into.
        """
        while stmts:
            last = stmts[-1]
            if isinstance(last, c_ast.Compound):
                stmts = last.block_items
                continue
            return isinstance(last, cls._CASE_TERMINATORS)
        return False

    @classmethod
    def _fallthrough_run(cls, labels, target):
        """Return ``target`` plus the directly preceding labels that fall into it.

        Returns None when ``target`` is not one of ``labels``.
        """
        run = []
        for label in labels:
            run.append(label)
            if label is target:
                return run
            if cls._ends_with_jump(label.stmts):
                run = []
        return None

    @classmethod
    def _label_terms(cls, cond, labels, run):
        """Return one match condition per label in ``run``.

        Returns None if a label in ``run`` always matches: a ``default``
        with no ``case`` siblings.
        """
        case_exprs = [label.expr for label in labels if isinstance(label, c_ast.Case)]
        terms = []
        for label in run:
            if isinstance(label, c_ast.Case):
                terms.append(c_ast.BinaryOp('==', cond, label.expr))
            elif case_exprs:
                unequal = [c_ast.BinaryOp('!=', cond, expr) for expr in case_exprs]
                terms.append(cls._fold('&&', unequal))
            else:
                return None
        return terms

    @staticmethod
    def _fold(op, terms):
        """Combine non-empty ``terms`` with binary ``op``: [a, b, c] -> c op (b op a)."""
        acc = terms[0]
        for term in terms[1:]:
            acc = c_ast.BinaryOp(op, term, acc)
        return acc
class ContainsOneOfIdVisitor(c_ast.NodeVisitor):
    """Check whether an expression mentions any identifier from ``ids``.

    After a visit, ``hit`` is True if some ``ID``/``TypedID`` name was in
    ``ids``. ``name`` holds the last such name encountered (None otherwise).
    """

    def __init__(self, ids):
        self.ids = ids
        self.hit = False
        self.name = None

    def visit_ID(self, node):
        if node.name not in self.ids:
            return None
        self.hit = True
        self.name = node.name
        return None

    def visit_TypedID(self, node):
        return self.visit_ID(node)
@dataclass
class StateTransition:
    """A transition between two states under a guard condition.

    Fields:
        state_from: source state, presumably an enumerator name.
        state_to: target state, presumably an enumerator name.
        condition: AST expression guarding the transition.
    """

    state_from: str
    state_to: str
    condition: c_ast.Node
class TypedefGatherer(c_ast.NodeVisitor):
    """Map every typedef name under a node to its ``Typedef`` node.

    Typedef bodies are not descended into. A repeated name keeps the last
    definition.
    """

    def __init__(self):
        self.typedefs = {}

    def visit_Typedef(self, node):
        name = node.name
        self.typedefs[name] = node
class AstTransformer(object):
    """Rebuild an AST, dispatching on node class name.

    ``transform(node)`` calls ``transform_<ClassName>(node)`` when a
    subclass defines it. Otherwise:
    - c_ast nodes are rebuilt with every field transformed;
    - lists are transformed element-wise;
    - any other value (strings, None, ...) is returned unchanged.
    """

    def __init__(self):
        # class name -> bound handler, filled lazily
        self._method_cache = {}

    def transform(self, node):
        visitor = self._method_cache.get(node.__class__.__name__, None)
        if visitor is None:
            method = 'transform_' + node.__class__.__name__
            visitor = getattr(self, method, self._transform_generic)
            self._method_cache[node.__class__.__name__] = visitor
        return visitor(node)

    def _transform_Node(self, node):
        """Rebuild ``node`` from its transformed fields.

        Field values are read in ``__slots__`` order and passed positionally
        to the node's constructor, so slot order must match ``__init__``
        parameter order. ``__weakref__`` is skipped by name. The previous
        "drop the last slot" slice instead dropped ``coord`` for classes
        such as ``TypedID`` whose slots do not end in ``__weakref__``.
        """
        fields = [name for name in node.__slots__ if name != '__weakref__']
        new_c = [self.transform(getattr(node, name)) for name in fields]
        return node.__class__(*new_c)

    def _transform_generic(self, node):
        if isinstance(node, c_ast.Node):
            return self._transform_Node(node)
        elif isinstance(node, list):
            return [self.transform(x) for x in node]
        else:
            return node
class AstPathTransformer(object):
    """Like ``AstTransformer``, but every handler also receives a path.

    ``path`` lists the values from the root of the transform down to and
    including the value being transformed. ``transform`` appends whatever it
    is given, so lists and non-node field values appear on it too. Each
    call gets its own copy. Handlers are ``transform_<ClassName>(node, path)``.
    """

    def __init__(self):
        # class name -> bound handler, filled lazily
        self._method_cache = {}

    def transform(self, node, path = None):
        if path is None:
            path = [node]
        else:
            path = path.copy()
            path.append(node)
        visitor = self._method_cache.get(node.__class__.__name__, None)
        if visitor is None:
            method = 'transform_' + node.__class__.__name__
            visitor = getattr(self, method, self._transform_generic)
            self._method_cache[node.__class__.__name__] = visitor
        return visitor(node, path)

    def _transform_Node(self, node, path):
        """Rebuild ``node`` from its transformed fields.

        Fields come from ``__slots__`` in order, with ``__weakref__``
        skipped by name, and are passed positionally to the constructor.
        This keeps ``coord`` for classes like ``TypedID`` whose slots do not
        end in ``__weakref__``.
        """
        fields = [name for name in node.__slots__ if name != '__weakref__']
        new_c = [self.transform(getattr(node, name), path) for name in fields]
        return node.__class__(*new_c)

    def _transform_generic(self, node, path):
        if isinstance(node, c_ast.Node):
            return self._transform_Node(node, path)
        elif isinstance(node, list):
            return [self.transform(x, path) for x in node]
        else:
            return node
class TypedID(c_ast.ID):
    """A ``c_ast.ID`` that also carries a ``type`` annotation.

    ``type`` is whatever the creator attaches. Examples are a C type
    spelling such as ``'_Bool'`` or an object taken from a type table. The
    node is a leaf: the annotation is metadata, not a child.
    """

    # Order matters: the generic transformers in this module rebuild nodes
    # by passing __slots__ values positionally to the constructor.
    __slots__ = ('name', 'type', 'coord')
    attr_names = ('name', 'type' )

    def __init__(self, name, type=None, coord=None):
        super().__init__(name, coord)
        self.type = type

    def children(self):
        return ()

    def __iter__(self):
        return iter(())
class CGenerator(c_generator.CGenerator):
    """pycparser C code generator that also understands ``TypedID`` nodes.

    A ``TypedID`` is emitted as its bare name; the type annotation is dropped.
    """

    def __init__(self):
        super().__init__()

    def visit_TypedID(self, n):
        identifier = n.name
        return identifier
class TypeDeclFinder(c_ast.NodeVisitor):
    """Find a ``TypeDecl`` whose ``declname`` equals ``declname``.

    ``decl`` holds the last match seen, or None. Matched nodes are not
    descended into.
    """

    def __init__(self, declname):
        self.declname = declname
        self.decl = None

    def visit_TypeDecl(self, node):
        if node.declname != self.declname:
            return
        self.decl = node
class DeclFinder(c_ast.NodeVisitor):
    """Find a ``Decl`` whose ``name`` equals ``declname``.

    ``decl`` holds the last match seen, or None. No ``Decl`` is descended
    into, so only the outermost declarations under the start node are
    examined.
    """

    def __init__(self, declname):
        self.declname = declname
        self.decl = None

    def visit_Decl(self, node):
        if node.name != self.declname:
            return
        self.decl = node
class IdentifierTypeFinder(c_ast.NodeVisitor):
    """Record the most recently visited top-level declaration.

    NOTE(review): this class looks unfinished. ``type`` is never assigned,
    and ``visit_Decl`` records every ``Decl`` it reaches (without descending)
    regardless of ``declname``. The name suggests it was meant to find the
    IdentifierType of ``declname``; confirm intent before relying on it.
    """

    def __init__(self, declname):
        self.declname = declname
        self.type = None
        # Initialised so reading `decl` before any Decl is seen yields None
        # instead of raising AttributeError.
        self.decl = None

    def visit_Decl(self, node):
        self.decl = node
class TypeFinder(object):
    """Collect every node, in pre-order, whose class name equals ``t.__name__``.

    The match is by class *name*, not ``isinstance``, so subclasses (e.g.
    ``TypedID`` when searching for ``ID``) are not matched. Results
    accumulate in ``result`` across calls to ``visit``.
    """

    def __init__(self, t):
        self.type = t
        self.result = []

    def visit(self, node):
        wanted = self.type.__name__
        pending = [node]
        while pending:
            current = pending.pop()
            if current.__class__.__name__ == wanted:
                self.result.append(current)
            # Push children reversed so the first child is handled next.
            pending.extend(reversed(list(current)))
class TypedIdTransformer(AstPathTransformer):
    """Annotate identifiers that name function parameters with their types.

    Only parameters of the first (outermost) ``FuncDef`` on the path are
    considered; everything that cannot be resolved is returned untyped.

    - An ``ID`` naming such a parameter becomes a ``TypedID``.
    - A ``StructRef`` whose base is such a parameter gets a ``TypedID``
      base typed with the struct's ``type_dict`` entry.
    - Its field becomes a ``TypedID`` with the member's type.
    - Nested chains (``a.b.c``) are resolved level by level.
    """

    def __init__(self, type_dict):
        super().__init__()
        # Looked up by a type spelling such as "fsm_t" or "unsigned char".
        # Presumably typedef name -> Typedef node (cf. TypedefGatherer);
        # confirm against callers.
        self.type_dict = type_dict

    @staticmethod
    def _enclosing_funcdef(path):
        """Return the first ``FuncDef`` on ``path``, or None."""
        return next((n for n in path if isinstance(n, c_ast.FuncDef)), None)

    @staticmethod
    def _param_type_name(funcdef, name):
        """Return the type spelling of parameter ``name`` of ``funcdef``.

        The spelling is the space-joined ``IdentifierType.names``, e.g.
        ``"unsigned char"``. The old code took only the first word, which
        typed ``unsigned char`` as ``unsigned``.

        Returns None when:
        - there is no function or no parameter list;
        - no parameter is called ``name``;
        - that parameter's base type is not an ``IdentifierType`` (e.g. a
          ``struct``).
        """
        if funcdef is None or funcdef.decl.type.args is None:
            return None
        for param in funcdef.decl.type.args.params:
            finder = TypeDeclFinder(name)
            finder.visit(param)
            if finder.decl is None:
                continue
            base = finder.decl.type
            if isinstance(base, c_ast.IdentifierType):
                return " ".join(base.names)
            return None
        return None

    @staticmethod
    def _member_type_name(struct_type, field_name):
        """Return the type spelling of member ``field_name`` of ``struct_type``.

        Returns None when:
        - ``struct_type`` is not an AST node;
        - it declares no member with that name;
        - the member's type contains no ``IdentifierType``.
        """
        if not isinstance(struct_type, c_ast.Node):
            return None
        finder = DeclFinder(field_name)
        finder.visit(struct_type)
        if finder.decl is None:
            return None
        member_type = finder.decl.type
        found = TypeFinder(c_ast.IdentifierType)
        found.visit(getattr(member_type, 'type', member_type))
        if not found.result:
            return None
        return " ".join(found.result[0].names)

    def transform_ID(self, node, path):
        type_name = self._param_type_name(self._enclosing_funcdef(path), node.name)
        if type_name is None:
            return node
        return TypedID(node.name, self.type_dict.get(type_name, type_name), node.coord)

    def transform_StructRef(self, node, path):
        if isinstance(node.name, c_ast.ID):
            # Base is a plain identifier: typed only if it is a parameter whose
            # type spelling is a key of type_dict.
            type_name = self._param_type_name(self._enclosing_funcdef(path), node.name.name)
            struct_type = None if type_name is None else self.type_dict.get(type_name, None)
            if struct_type is None:
                return c_ast.StructRef(node.name, node.type, node.field, node.coord)
            name_node = TypedID(node.name.name, struct_type, node.name.coord)
        else:
            # Nested access: type the inner reference first, then look the
            # member up in the type attached to its field.
            name_node = self.transform(node.name, path)
            inner_field = name_node.field if isinstance(name_node, c_ast.StructRef) else None
            struct_type = getattr(inner_field, 'type', None)
        member_type = self._member_type_name(struct_type, node.field.name)
        if member_type is None:
            field_node = node.field
        else:
            field_node = TypedID(node.field.name,
                                 self.type_dict.get(member_type, member_type),
                                 node.field.coord)
        return c_ast.StructRef(name_node, node.type, field_node, node.coord)
class AstFuncCallInlinerTransformer(object):
    """Return a copy of an AST with calls to known functions inlined.

    Expansion of a call to a function in ``func_table``:
    - The ``FuncCall`` node is replaced by a transformed copy of the
      callee's ``FuncDef``.
    - Inside that copy, identifiers naming the callee's parameters are
      replaced by the corresponding call arguments.
    - The arguments have already had the caller's own substitutions
      applied, so nested inlining resolves all the way through.
    - The callee body sees only its own parameter bindings, because C is
      lexically scoped.
    - A call to a function that is already being expanded (recursion) is
      left as a call.

    ``idtable`` maps identifier names to replacement nodes.
    """

    def __init__(self, func_table):
        self._method_cache = {}
        self.func_table = func_table
        # Names of functions currently being expanded (recursion guard).
        self._inlining = set()

    def transform(self, node, idtable = None):
        if idtable is None:
            idtable = {}
        visitor = self._method_cache.get(node.__class__.__name__, None)
        if visitor is None:
            method = 'transform_' + node.__class__.__name__
            visitor = getattr(self, method, self._transform_generic)
            self._method_cache[node.__class__.__name__] = visitor
        return visitor(node, idtable)

    def _transform_Node(self, node, idtable):
        """Rebuild ``node`` from its transformed fields (``__weakref__`` skipped)."""
        fields = [name for name in node.__slots__ if name != '__weakref__']
        new_c = [self.transform(getattr(node, name), idtable) for name in fields]
        return node.__class__(*new_c)

    def _transform_generic(self, node, idtable):
        if isinstance(node, c_ast.Node):
            return self._transform_Node(node, idtable)
        elif isinstance(node, list):
            return [self.transform(x, idtable) for x in node]
        else:
            return node

    @staticmethod
    def _param_decls(funcdef):
        """Return ``funcdef``'s parameter declarations; ``f()`` and ``f(void)`` give []."""
        args = funcdef.decl.type.args
        if args is None:
            return []
        params = list(args.params)
        if len(params) == 1:
            only = params[0]
            if (isinstance(only, c_ast.Typename)
                    and isinstance(only.type, c_ast.TypeDecl)
                    and isinstance(only.type.type, c_ast.IdentifierType)
                    and only.type.type.names == ['void']):
                return []
        return params

    def transform_FuncCall(self, node, idtable):
        call_args = [] if node.args is None else node.args.exprs
        # Apply the caller's substitutions before binding to the callee.
        new_args = [self.transform(arg, idtable) for arg in call_args]
        fcall = node.name.name if isinstance(node.name, c_ast.ID) else None
        if fcall is None or fcall not in self.func_table or fcall in self._inlining:
            # Not expanded: keep the call, with substituted arguments.
            name = node.name if fcall is not None else self.transform(node.name, idtable)
            args = None if node.args is None else c_ast.ExprList(new_args, node.args.coord)
            return c_ast.FuncCall(name, args, node.coord)
        funcdef = self.func_table[fcall]
        params = self._param_decls(funcdef)
        if len(new_args) != len(params):
            raise Exception("Func Call does not match argument number")
        bindings = {}
        for callarg, defarg in zip(new_args, params):
            tdec_finder = TypeFinder(c_ast.TypeDecl)
            tdec_finder.visit(defarg)
            # Only parameters with a single TypeDecl have one unambiguous name
            # (function-pointer parameters contain several).
            if len(tdec_finder.result) == 1:
                bindings[tdec_finder.result[0].declname] = callarg
        self._inlining.add(fcall)
        try:
            return self.transform(funcdef, bindings)
        finally:
            self._inlining.discard(fcall)

    def transform_ID(self, node, idtable):
        if node.name in idtable:
            return idtable[node.name]
        return node

    def transform_TypedID(self, node, idtable):
        return self.transform_ID(node, idtable)
def find_type_of_branch(node):
    """Return the type annotation at the end of an expression branch.

    Member chains are followed through their fields to the final one. The
    result is that node's ``type`` if it is a ``TypedID``, else None.
    """
    while isinstance(node, c_ast.StructRef):
        node = node.field
    return node.type if isinstance(node, TypedID) else None
class NuSmvConditionTransformer(AstTransformer):
    """Rewrite a (typed) C condition expression into NuSMV expression syntax.

    - ``&&``, ``||`` and ``==`` become ``&``, ``|`` and ``=``; ``%`` becomes
      NuSMV's ``mod``. Other operators are kept verbatim.
    - A ``_Bool`` operand compared with a constant gets the constant replaced
      by ``FALSE`` (value 0) or ``TRUE`` (anything else).
    - Member chains are flattened to one identifier (``a.b`` -> ``a_b``).
    - Integer constants are written in decimal without C suffixes. Other
      constants just lose a trailing letter suffix.
    """

    _OP_MAP = {'&&': '&', '||': '|', '==': '=', '%': 'mod'}

    # C integer literal: hex / binary / octal / decimal digits plus an
    # optional unsigned/long suffix.
    _INT_LITERAL = re.compile(
        r'(0[xX][0-9a-fA-F]+|0[bB][01]+|0[0-7]*|[1-9][0-9]*)[uUlL]*')

    def __init__(self):
        super().__init__()

    @classmethod
    def _int_literal_value(cls, text):
        """Return the value of C integer literal ``text``, or None if it is not one."""
        match = cls._INT_LITERAL.fullmatch(text)
        if match is None:
            return None
        digits = match.group(1)
        prefix = digits[:2].lower()
        if prefix == '0x':
            return int(digits, 16)
        if prefix == '0b':
            return int(digits, 2)
        if len(digits) > 1 and digits[0] == '0':
            return int(digits, 8)  # leading zero means octal in C
        return int(digits)

    @classmethod
    def _bool_literal(cls, constant):
        """Map a C constant to the NuSMV boolean ``FALSE`` (zero) or ``TRUE``."""
        value = cls._int_literal_value(constant.value)
        is_false = constant.value == '0' if value is None else value == 0
        return c_ast.ID("FALSE" if is_false else "TRUE")

    def transform_BinaryOp(self, node):
        op = self._OP_MAP.get(node.op, node.op)
        lhs = node.left
        rhs = node.right
        lhs_is_bool = find_type_of_branch(lhs) == '_Bool'
        rhs_is_bool = find_type_of_branch(rhs) == '_Bool'
        # NuSMV booleans are not integers: compare a _Bool with TRUE/FALSE.
        if lhs_is_bool and isinstance(rhs, c_ast.Constant):
            rhs = self._bool_literal(rhs)
        elif rhs_is_bool and isinstance(lhs, c_ast.Constant):
            lhs = self._bool_literal(lhs)
        lhs = self.transform(lhs)
        rhs = self.transform(rhs)
        return c_ast.BinaryOp(op, lhs, rhs)

    def transform_ID(self, node):
        return c_ast.ID(node.name)

    def transform_TypedID(self, node):
        return c_ast.ID(node.name)

    def transform_StructRef(self, node):
        def srtf(node):
            if isinstance(node, c_ast.StructRef):
                return f"{srtf(node.name)}_{srtf(node.field)}"
            elif isinstance(node, TypedID) or isinstance(node, c_ast.ID):
                return node.name
            return node
        full = srtf(node)
        return c_ast.ID(full)

    def transform_Constant(self, node):
        value = self._int_literal_value(node.value)
        if value is not None:
            # Decimal spelling: the old suffix regex also ate hex digits
            # ("0xFF" -> "0x"), and hex/octal are not plain NuSMV integers.
            text = str(value)
        else:
            text = re.sub(r'([a-zA-Z]+)$', '', node.value)
        return c_ast.Constant(node.type, text)

    def transform_UnaryOp(self, node):
        return c_ast.UnaryOp(node.op, self.transform(node.expr))
class NuSmvVariableExtractor(c_ast.NodeVisitor):
    """Collect NuSMV variable declarations from a typed condition AST.

    ``variables`` maps each variable name to its NuSMV type:
    - a ``TypedID`` contributes its own name;
    - a ``StructRef`` contributes its flattened member chain (``a.b`` ->
      ``a_b``), typed by its final field.

    C types are translated through ``TYPE_LUT``; any other type raises.
    """

    # int uses the same ranges as short, i.e. a 16-bit int is assumed.
    # Signed ranges are two's complement (previously mistyped as
    # '32767..32767' and '-127..127').
    SHORT_TYPE = '-32768..32767'
    USHORT_TYPE = '0..65535'
    INT_TYPE = '-32768..32767'
    UINT_TYPE = '0..65535'
    TYPE_LUT = {
        '_Bool':'boolean',
        # plain char is treated as unsigned
        'char':'0..255',
        'signed char':'-128..127',
        'unsigned char':'0..255',
        'short':SHORT_TYPE,
        'short int':SHORT_TYPE,
        'signed short':SHORT_TYPE,
        'signed short int':SHORT_TYPE,
        'unsigned short':USHORT_TYPE,
        'unsigned short int':USHORT_TYPE,
        'int':INT_TYPE,
        'signed':INT_TYPE,
        'signed int':INT_TYPE,
        'unsigned':UINT_TYPE,
        'unsigned int':UINT_TYPE
    }

    def __init__(self):
        self.variables = {}

    def visit_TypedID(self, node):
        smvtype = self.TYPE_LUT.get(node.type, None)
        if smvtype is None:
            raise Exception(f"Type '{node.type}' is not supported")
        self.variables[node.name] = smvtype

    def visit_StructRef(self, node):
        # The chain's base is not visited: presumably its type is the struct
        # itself, which TYPE_LUT cannot map.
        def srtf(node):
            if isinstance(node, c_ast.StructRef):
                return f"{srtf(node.name)}_{srtf(node.field)}"
            elif isinstance(node, TypedID) or isinstance(node, c_ast.ID):
                return node.name
            return node
        field = node.field
        while isinstance(field, c_ast.StructRef):
            field = field.field
        full = srtf(node)
        # An untyped field reports "not supported" instead of AttributeError.
        field_type = getattr(field, 'type', None)
        smvtype = self.TYPE_LUT.get(field_type, None)
        if smvtype is None:
            raise Exception(f"Type '{field_type}' is not supported")
        self.variables[full] = smvtype