From 57cb073ff42f633d3db2985f31388a165c6a185e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dominic=20H=C3=B6glinger?= Date: Tue, 22 Nov 2022 18:25:30 +0100 Subject: [PATCH] added function deferring with invariants (states that hold true) in the execution path or something, I'm no prof good news is that state processing functions should work --- analyze.py | 75 +++++++++++++++++++++-------- astvisitors.py | 128 +++++++++++++++++++++++++++++++++++++++++-------- 2 files changed, 161 insertions(+), 42 deletions(-) diff --git a/analyze.py b/analyze.py index 8bb1a77..fdf6d13 100644 --- a/analyze.py +++ b/analyze.py @@ -4,6 +4,7 @@ import os,sys,io from pycparser import parse_file, c_ast, CParser from pcpp import Preprocessor import graphviz as gv +import html from modelbuilder import * from utils import * @@ -56,6 +57,7 @@ if __name__ == "__main__": #ast = parse_file(args.filename, use_cpp=False) #ast.show() + initial_state = args.initial assign_table = [] state_enums = [] @@ -63,28 +65,45 @@ if __name__ == "__main__": enum_table = {} state_asmts = [] fsm_funcs = [] - proc_func = None fdv = FuncDefVisitor() fdv.visit(ast) func_table = fdv.func_table + proc_func = None + if not(args.func in func_table): + raise Exception(f"Function name '{args.func}' not found!") + else: + proc_func = func_table[args.func] etv = EnumTypedefVisitor() etv.visit(ast) enum_table = etv.enums states = [] - if not(args.func in func_table): - raise Exception(f"Function name '{args.func}' not found!") - else: - proc_func = func_table[args.func] - fsm_funcs.append(proc_func) + + def discover_cals(func, visited=None): + if visited is None: + visited = [] + funcs = [] + child_funcs = [] fcv = FuncCallVisitor() - fcv.visit(proc_func) + fcv.visit(func) + visited.append(proc_func) for fc in fcv.func_calls: if fc in func_table: - fsm_funcs.append(func_table[fc]) - #fsm_funcs += [ func_table[x] for x in fcv.func_calls ] + funcs.append(func_table[fc]) + child_funcs += discover_cals(func_table[fc], visited) + return funcs + child_funcs + + # discover FSM functions + fsm_funcs.append(proc_func) + fsm_funcs += discover_cals(proc_func) + #fsm_funcs += [ func_table[x] for x in fcv.func_calls ] + print("Function Table") + for f in fsm_funcs: + print(f" - {f.decl.name} {'<<< entry' if (f.decl.name == args.func) else ''}") + print("") + print("Enum Table") if args.enum in enum_table: ev = EnumVisitor() @@ -92,24 +111,25 @@ if __name__ == "__main__": for ename in ev.enum_names: print(" - ",ename) states.append(ename) - for f in fsm_funcs: - sav = StateAssignmentVisitor(ename) - sav.visit(f) - state_asmts += sav.assignments + #for f in fsm_funcs: + sav = StateAssignmentVisitor(ast, ename) + #sav.visit(f) + sav.visit(proc_func) + state_asmts += sav.assignments else: print(f"Initial State Enum '{args.enum}' not found") paths = [] for asm in state_asmts: paths.append(asm[1]) - + #common = find_common_ancestor(paths) tran_table = [] for sa in state_asmts: for f in fsm_funcs: - sctv = SwitchCaseTranVisitor(sa[0], states, sa[2], sa[3]) + sctv = SwitchCaseTranVisitor(sa[0], sa[1], states, sa[2], sa[3], sa[4]) sctv.visit(f) tran_table += sctv.tran_table @@ -130,6 +150,7 @@ if __name__ == "__main__": print("States: ", ",".join(states)) print("") + prop_func_blacklist = [f.decl.name for f in fsm_funcs] reachable_states = [initial_state] props_by_state = {} print("Compact Transition Table:") @@ -141,14 +162,12 @@ if __name__ == "__main__": reachable_states.append(s) # find properties for f in fsm_funcs: - sccpv = SwitchCaseCodePropertyVisitor(n, pure_sa) + sccpv = SwitchCaseCodePropertyVisitor(n, pure_sa, prop_func_blacklist) sccpv.visit(f) if len(sccpv.properties) > 0: props_by_state[n] = sccpv.properties print("") - print("Reachable States") - states_by_property = {} for state,props in props_by_state.items(): for prop in props: @@ -160,7 +179,6 @@ if __name__ == "__main__": properties = {} property_alias = {} for i,(prop,pstates) in enumerate(states_by_property.items()): - print("foo ", i) alias = base_10_to_alphabet(i + 1) property_alias[prop] = alias properties[alias] = (pstates, prop) @@ -194,6 +212,7 @@ if __name__ == "__main__": if args.dot is not None: g = gv.Digraph('G') + # add states for state in reachable_states: state_short = state_to_short[state] shape = 'oval' @@ -205,9 +224,23 @@ if __name__ == "__main__": g.node(state_short, label=f"{state_short}\n\n{{{pstr}}}", shape=shape) else: g.node(state_short, label=state_short, shape=shape) - - + + # add transitions for n,ms in comp_tt.items(): for m in ms: g.edge(state_to_short[n], state_to_short[m]) + + # add property table + """ + tabattr = 'border="0px"' + table_rows = [] + for prop,alias in property_alias.items(): + alias = html.escape(alias) + prop = html.escape(prop) + table_rows.append(f"{alias}{prop}") + table_rows = ''.join(table_rows) + html = f"{table_rows}
" + print(html) + g.node("Properties", label=f"<{html}>",shape="none", rank="sink") + """ g.render(filename=args.dot) diff --git a/astvisitors.py b/astvisitors.py index 22ed012..4c524d8 100644 --- a/astvisitors.py +++ b/astvisitors.py @@ -32,6 +32,14 @@ def path_slice_back(node, p): s.append(n) return [x for x in reversed(s)] + +def path_filter(p, t): + elems = [] + for n in p: + if isinstance(n, t): + elems.append(n) + return elems + class ExprListSerializerVisitor(c_ast.NodeVisitor): def __init__(self): self.serial = [] @@ -46,6 +54,14 @@ class ExprListSerializerVisitor(c_ast.NodeVisitor): #todo: expand +class CaseLabelExtractionVisitor(c_ast.NodeVisitor): + def __init__(self): + self.label = None + + def visit_ID(self, node): + self.label = node.name + + def expr_list_to_str(exprl): elsv = ExprListSerializerVisitor() elsv.visit(exprl) @@ -72,6 +88,37 @@ class NodeVisitorWithParent(object): self.visit(c) self.current_parent = oldparent +class NodeVisitorFuncCallForward(object): + def __init__(self, ast): + self.current_parent = None + self.ast = ast + fdv = FuncDefVisitor() + fdv.visit(ast) + self.func_table = fdv.func_table + + def visit_FuncCall(self, node): + print("Visiting FuncCall") + print(node.show()) + print('---- parent ----') + print(self.current_parent.show()) + + def visit(self, node): + """ Visit a node. + """ + method = 'visit_' + node.__class__.__name__ + visitor = getattr(self, method, self.generic_visit) + return visitor(node) + + def generic_visit(self, node): + """ Called if no explicit visitor function exists for a + node. Implements preorder visiting of the node. + """ + oldparent = self.current_parent + self.current_parent = node + for c in node.children(): + self.visit(c) + self.current_parent = oldparent + class FuncDefVisitor(c_ast.NodeVisitor): def __init__(self): self.func_table = {} @@ -86,17 +133,20 @@ class FuncCallVisitor(c_ast.NodeVisitor): def visit_FuncCall(self, node): self.func_calls.append(node.children()[0][1].name) -class StateAssignmentVisitor(c_ast.NodeVisitor): - def __init__(self, state): - super().__init__() +class StateAssignmentVisitor(NodeVisitorFuncCallForward): + def __init__(self, ast, state): + super().__init__(ast) + self._method_cache = {} self.state = state self.assignments = [] - def visit(self, node, path = None): + def visit(self, node, path = None, invariants = None): """ Visit a node. """ if path is None: path = [] + if invariants is None: + invariants = [] if self._method_cache is None: self._method_cache = {} @@ -106,25 +156,44 @@ class StateAssignmentVisitor(c_ast.NodeVisitor): visitor = getattr(self, method, self.generic_visit) self._method_cache[node.__class__.__name__] = visitor - return visitor(node, path) - - def generic_visit(self, node, path): + return visitor(node, path, invariants) + + def generic_visit(self, node, path, invariants): """ Called if no explicit visitor function exists for a node. Implements preorder visiting of the node. """ path = path.copy() path.append(node) for c in node: - self.visit(c, path) + self.visit(c, path, invariants) - def visit_Assignment(self, n, path): + def visit_FuncCall(self, node, path, invariants): + fcall = node.name.name + #print("CAL path", path_to_str(path)) + cases = path_filter(path, c_ast.Case) + #print(cases) + new_invariants = [x.expr.name for x in cases] + invariants = invariants.copy() + for ni in new_invariants: + if not(ni in invariants): + invariants.append(ni) + #print("invariants:", invariants) + #print(f"Visiting FuncCall {fcall}") + if fcall in self.func_table: + #print("->deferring!") + self.visit(self.func_table[fcall], path, invariants) + #print('---- path ----') + + def visit_Assignment(self, n, path, invariants): # fallthrough detection case_node = path_select_last('Case', path) fallthrough_case_names = [] if case_node is not None: - case_parent = path_select_parent(case_node, path) #path_select_last('Compound', path) + case_parent = path_select_parent(case_node, path) siblings = [x for x in case_parent.block_items if not isinstance(x, c_ast.Default)] - sibling_names = [x.expr.name for x in siblings] + #for s in siblings: + # s.show() + sibling_names = [(x.expr.expr.name if isinstance(x.expr, c_ast.Cast) else x.expr.name) for x in siblings] sibling_empty = [(len(x.stmts) == 0) for x in siblings] in_fallthrough = False @@ -153,10 +222,11 @@ class StateAssignmentVisitor(c_ast.NodeVisitor): is_exhaustive_conditional = True if asm_if is None else (asm_if.iftrue is not None) and (asm_if.iffalse is not None) is_conditional = path_contains(subpath_case, c_ast.If) and not(is_exhaustive_conditional) - if not(isinstance(n.rvalue, c_ast.Constant) or isinstance(n.rvalue, c_ast.BinaryOp)): + if not(isinstance(n.rvalue, c_ast.Constant) or isinstance(n.rvalue, c_ast.BinaryOp) or isinstance(n.rvalue, c_ast.TernaryOp)): rval_str = n.rvalue.name + #print(f">>>> {rval_str} == {self.state}") if rval_str == self.state: - self.assignments.append((n,path,fallthrough_case_names,is_conditional)) + self.assignments.append((n,path,fallthrough_case_names,is_conditional,invariants)) class EnumDefVisitor(c_ast.NodeVisitor): def __init__(self, name): @@ -196,19 +266,30 @@ class SwitchCaseTermVisitor(c_ast.NodeVisitor): self.hit = True class SwitchCaseTranVisitor(c_ast.NodeVisitor): - def __init__(self, asm_node, states, fallthrough_states, is_conditional): + def __init__(self, asm_node, path, states, fallthrough_states, is_conditional, invariants): super().__init__() self.states = states + self.path = path + self.invariants = invariants self.fallthrough_states = fallthrough_states self.is_conditional = is_conditional self._asm_node = asm_node self.tran_table = [] def visit_Case(self, node): - state_from = node.children()[0][1].name + clev = CaseLabelExtractionVisitor() + clev.visit(node.children()[0][1]) + state_from = clev.label + # highly inefficient but it is what it is + hit = False + #for n in self.path: sctv = SwitchCaseTermVisitor(self._asm_node) + #node.show() sctv.visit(node) - if sctv.hit and (state_from in self.states): + hit = sctv.hit + + #print(state_from, "->", self._asm_node.rvalue.name, "? hit=", hit, "isstate=", (state_from in self.states), "invar=", self.invariants) + if (hit or (state_from in self.invariants)) and (state_from in self.states): #if conditional, state remains in state sometimes if self.is_conditional: self.tran_table.append((state_from,state_from)) @@ -218,8 +299,9 @@ class SwitchCaseTranVisitor(c_ast.NodeVisitor): self.tran_table.append((ft, self._asm_node.rvalue.name)) class SwitchCasePropertyVisitor(c_ast.NodeVisitor): - def __init__(self, state_asmts): + def __init__(self, state_asmts, func_blacklist = None): super().__init__() + self._func_bl = func_blacklist self._sas = state_asmts self.properties = [] @@ -227,7 +309,8 @@ class SwitchCasePropertyVisitor(c_ast.NodeVisitor): name = node.name.name args = expr_list_to_str(node.args) if node.args is not None else "" fcall = f"{name}({args})" - self.properties.append(fcall) + if not(name in self._func_bl): + self.properties.append(fcall) def visit_Assignment(self, node): if not(node in self._sas): @@ -248,8 +331,9 @@ class SwitchCasePropertyVisitor(c_ast.NodeVisitor): class SwitchCaseCodePropertyVisitor(c_ast.NodeVisitor): - def __init__(self, case, state_asmts): + def __init__(self, case, state_asmts, func_blacklist): super().__init__() + self._func_bl = func_blacklist self._case = case self._sas = state_asmts self.properties = [] @@ -257,8 +341,10 @@ class SwitchCaseCodePropertyVisitor(c_ast.NodeVisitor): def visit_Case(self, node): label = node.children()[0][1] block = node - if label.name == self._case: - scpv = SwitchCasePropertyVisitor(self._sas) + clev = CaseLabelExtractionVisitor() + clev.visit(label) + if clev.label == self._case: + scpv = SwitchCasePropertyVisitor(self._sas, self._func_bl) scpv.visit(block) self.properties += scpv.properties