diff --git a/analyze.py b/analyze.py
index 8bb1a77..fdf6d13 100644
--- a/analyze.py
+++ b/analyze.py
@@ -4,6 +4,7 @@ import os,sys,io
from pycparser import parse_file, c_ast, CParser
from pcpp import Preprocessor
import graphviz as gv
+import html
from modelbuilder import *
from utils import *
@@ -56,6 +57,7 @@ if __name__ == "__main__":
#ast = parse_file(args.filename, use_cpp=False)
#ast.show()
+
initial_state = args.initial
assign_table = []
state_enums = []
@@ -63,28 +65,45 @@ if __name__ == "__main__":
enum_table = {}
state_asmts = []
fsm_funcs = []
- proc_func = None
fdv = FuncDefVisitor()
fdv.visit(ast)
func_table = fdv.func_table
+ proc_func = None
+ if not(args.func in func_table):
+ raise Exception(f"Function name '{args.func}' not found!")
+ else:
+ proc_func = func_table[args.func]
etv = EnumTypedefVisitor()
etv.visit(ast)
enum_table = etv.enums
states = []
- if not(args.func in func_table):
- raise Exception(f"Function name '{args.func}' not found!")
- else:
- proc_func = func_table[args.func]
- fsm_funcs.append(proc_func)
+
+ def discover_cals(func, visited=None):
+ if visited is None:
+ visited = []
+ funcs = []
+ child_funcs = []
fcv = FuncCallVisitor()
- fcv.visit(proc_func)
+ fcv.visit(func)
+ visited.append(proc_func)
for fc in fcv.func_calls:
if fc in func_table:
- fsm_funcs.append(func_table[fc])
- #fsm_funcs += [ func_table[x] for x in fcv.func_calls ]
+ funcs.append(func_table[fc])
+ child_funcs += discover_cals(func_table[fc], visited)
+ return funcs + child_funcs
+
+ # discover FSM functions
+ fsm_funcs.append(proc_func)
+ fsm_funcs += discover_cals(proc_func)
+ #fsm_funcs += [ func_table[x] for x in fcv.func_calls ]
+ print("Function Table")
+ for f in fsm_funcs:
+ print(f" - {f.decl.name} {'<<< entry' if (f.decl.name == args.func) else ''}")
+ print("")
+
print("Enum Table")
if args.enum in enum_table:
ev = EnumVisitor()
@@ -92,24 +111,25 @@ if __name__ == "__main__":
for ename in ev.enum_names:
print(" - ",ename)
states.append(ename)
- for f in fsm_funcs:
- sav = StateAssignmentVisitor(ename)
- sav.visit(f)
- state_asmts += sav.assignments
+ #for f in fsm_funcs:
+ sav = StateAssignmentVisitor(ast, ename)
+ #sav.visit(f)
+ sav.visit(proc_func)
+ state_asmts += sav.assignments
else:
print(f"Initial State Enum '{args.enum}' not found")
paths = []
for asm in state_asmts:
paths.append(asm[1])
-
+
#common = find_common_ancestor(paths)
tran_table = []
for sa in state_asmts:
for f in fsm_funcs:
- sctv = SwitchCaseTranVisitor(sa[0], states, sa[2], sa[3])
+ sctv = SwitchCaseTranVisitor(sa[0], sa[1], states, sa[2], sa[3], sa[4])
sctv.visit(f)
tran_table += sctv.tran_table
@@ -130,6 +150,7 @@ if __name__ == "__main__":
print("States: ", ",".join(states))
print("")
+ prop_func_blacklist = [f.decl.name for f in fsm_funcs]
reachable_states = [initial_state]
props_by_state = {}
print("Compact Transition Table:")
@@ -141,14 +162,12 @@ if __name__ == "__main__":
reachable_states.append(s)
# find properties
for f in fsm_funcs:
- sccpv = SwitchCaseCodePropertyVisitor(n, pure_sa)
+ sccpv = SwitchCaseCodePropertyVisitor(n, pure_sa, prop_func_blacklist)
sccpv.visit(f)
if len(sccpv.properties) > 0:
props_by_state[n] = sccpv.properties
print("")
- print("Reachable States")
-
states_by_property = {}
for state,props in props_by_state.items():
for prop in props:
@@ -160,7 +179,6 @@ if __name__ == "__main__":
properties = {}
property_alias = {}
for i,(prop,pstates) in enumerate(states_by_property.items()):
- print("foo ", i)
alias = base_10_to_alphabet(i + 1)
property_alias[prop] = alias
properties[alias] = (pstates, prop)
@@ -194,6 +212,7 @@ if __name__ == "__main__":
if args.dot is not None:
g = gv.Digraph('G')
+ # add states
for state in reachable_states:
state_short = state_to_short[state]
shape = 'oval'
@@ -205,9 +224,23 @@ if __name__ == "__main__":
g.node(state_short, label=f"{state_short}\n\n{{{pstr}}}", shape=shape)
else:
g.node(state_short, label=state_short, shape=shape)
-
-
+
+ # add transitions
for n,ms in comp_tt.items():
for m in ms:
g.edge(state_to_short[n], state_to_short[m])
+
+ # add property table
+ """
+ tabattr = 'border="0px"'
+ table_rows = []
+ for prop,alias in property_alias.items():
+ alias = html.escape(alias)
+ prop = html.escape(prop)
+ table_rows.append(f"
| {alias} | {prop} |
")
+ table_rows = ''.join(table_rows)
+ html = f""
+ print(html)
+ g.node("Properties", label=f"<{html}>",shape="none", rank="sink")
+ """
g.render(filename=args.dot)
diff --git a/astvisitors.py b/astvisitors.py
index 22ed012..4c524d8 100644
--- a/astvisitors.py
+++ b/astvisitors.py
@@ -32,6 +32,14 @@ def path_slice_back(node, p):
s.append(n)
return [x for x in reversed(s)]
+
def path_filter(p, t):
    """Return the elements of path ``p`` that are instances of ``t``.

    ``t`` is anything ``isinstance`` accepts as its second argument (a
    single type or a tuple of types). The order of ``p`` is preserved, and
    a new list is returned even when nothing matches.
    """
    return [n for n in p if isinstance(n, t)]
+
class ExprListSerializerVisitor(c_ast.NodeVisitor):
def __init__(self):
self.serial = []
@@ -46,6 +54,14 @@ class ExprListSerializerVisitor(c_ast.NodeVisitor):
#todo: expand
class CaseLabelExtractionVisitor(c_ast.NodeVisitor):
    """Pull the identifier name out of a ``case`` label expression.

    After ``visit(expr)``, ``label`` holds the ``name`` of the ``ID`` node
    that was visited, or ``None`` if no ``ID`` was reached. If more than
    one ``ID`` is visited, the one visited last wins. Use a fresh instance
    for each label, because ``label`` is never reset.
    """

    def __init__(self):
        # Kept consistent with the other visitors in this module, which
        # all initialise their base class.
        super().__init__()
        self.label = None

    def visit_ID(self, node):
        """Record the identifier's name as the label."""
        self.label = node.name
+
+
def expr_list_to_str(exprl):
elsv = ExprListSerializerVisitor()
elsv.visit(exprl)
@@ -72,6 +88,37 @@ class NodeVisitorWithParent(object):
self.visit(c)
self.current_parent = oldparent
class NodeVisitorFuncCallForward(object):
    """Base visitor that knows every function definition in the AST.

    On construction it runs ``FuncDefVisitor`` over ``ast`` and keeps the
    resulting ``func_table``, so subclasses can resolve a call to the
    definition it targets. While ``generic_visit`` walks the tree,
    ``current_parent`` holds the node whose children are being visited.
    """

    def __init__(self, ast):
        self.current_parent = None
        self.ast = ast
        fdv = FuncDefVisitor()
        fdv.visit(ast)
        self.func_table = fdv.func_table

    def visit_FuncCall(self, node):
        """Dump a call node and its parent to stdout (diagnostic default).

        The call's own children are not visited.
        """
        print("Visiting FuncCall")
        # show() does its own printing elsewhere in this tool (see
        # `ast.show()`), so it is called directly rather than printed.
        node.show()
        print('---- parent ----')
        # A call visited as the traversal root has no parent to show.
        if self.current_parent is not None:
            self.current_parent.show()

    def visit(self, node):
        """ Visit a node.
        """
        method = 'visit_' + node.__class__.__name__
        visitor = getattr(self, method, self.generic_visit)
        return visitor(node)

    def generic_visit(self, node):
        """ Called if no explicit visitor function exists for a
            node. Implements preorder visiting of the node.
        """
        oldparent = self.current_parent
        self.current_parent = node
        # children() yields (name, child) pairs; only the child node is
        # visited.
        for _, child in node.children():
            self.visit(child)
        self.current_parent = oldparent
+
class FuncDefVisitor(c_ast.NodeVisitor):
def __init__(self):
self.func_table = {}
@@ -86,17 +133,20 @@ class FuncCallVisitor(c_ast.NodeVisitor):
def visit_FuncCall(self, node):
self.func_calls.append(node.children()[0][1].name)
-class StateAssignmentVisitor(c_ast.NodeVisitor):
- def __init__(self, state):
- super().__init__()
+class StateAssignmentVisitor(NodeVisitorFuncCallForward):
+ def __init__(self, ast, state):
+ super().__init__(ast)
+ self._method_cache = {}
self.state = state
self.assignments = []
- def visit(self, node, path = None):
+ def visit(self, node, path = None, invariants = None):
""" Visit a node.
"""
if path is None:
path = []
+ if invariants is None:
+ invariants = []
if self._method_cache is None:
self._method_cache = {}
@@ -106,25 +156,44 @@ class StateAssignmentVisitor(c_ast.NodeVisitor):
visitor = getattr(self, method, self.generic_visit)
self._method_cache[node.__class__.__name__] = visitor
- return visitor(node, path)
-
- def generic_visit(self, node, path):
+ return visitor(node, path, invariants)
+
+ def generic_visit(self, node, path, invariants):
""" Called if no explicit visitor function exists for a
node. Implements preorder visiting of the node.
"""
path = path.copy()
path.append(node)
for c in node:
- self.visit(c, path)
+ self.visit(c, path, invariants)
- def visit_Assignment(self, n, path):
+ def visit_FuncCall(self, node, path, invariants):
+ fcall = node.name.name
+ #print("CAL path", path_to_str(path))
+ cases = path_filter(path, c_ast.Case)
+ #print(cases)
+ new_invariants = [x.expr.name for x in cases]
+ invariants = invariants.copy()
+ for ni in new_invariants:
+ if not(ni in invariants):
+ invariants.append(ni)
+ #print("invariants:", invariants)
+ #print(f"Visiting FuncCall {fcall}")
+ if fcall in self.func_table:
+ #print("->deferring!")
+ self.visit(self.func_table[fcall], path, invariants)
+ #print('---- path ----')
+
+ def visit_Assignment(self, n, path, invariants):
# fallthrough detection
case_node = path_select_last('Case', path)
fallthrough_case_names = []
if case_node is not None:
- case_parent = path_select_parent(case_node, path) #path_select_last('Compound', path)
+ case_parent = path_select_parent(case_node, path)
siblings = [x for x in case_parent.block_items if not isinstance(x, c_ast.Default)]
- sibling_names = [x.expr.name for x in siblings]
+ #for s in siblings:
+ # s.show()
+ sibling_names = [(x.expr.expr.name if isinstance(x.expr, c_ast.Cast) else x.expr.name) for x in siblings]
sibling_empty = [(len(x.stmts) == 0) for x in siblings]
in_fallthrough = False
@@ -153,10 +222,11 @@ class StateAssignmentVisitor(c_ast.NodeVisitor):
is_exhaustive_conditional = True if asm_if is None else (asm_if.iftrue is not None) and (asm_if.iffalse is not None)
is_conditional = path_contains(subpath_case, c_ast.If) and not(is_exhaustive_conditional)
- if not(isinstance(n.rvalue, c_ast.Constant) or isinstance(n.rvalue, c_ast.BinaryOp)):
+ if not(isinstance(n.rvalue, c_ast.Constant) or isinstance(n.rvalue, c_ast.BinaryOp) or isinstance(n.rvalue, c_ast.TernaryOp)):
rval_str = n.rvalue.name
+ #print(f">>>> {rval_str} == {self.state}")
if rval_str == self.state:
- self.assignments.append((n,path,fallthrough_case_names,is_conditional))
+ self.assignments.append((n,path,fallthrough_case_names,is_conditional,invariants))
class EnumDefVisitor(c_ast.NodeVisitor):
def __init__(self, name):
@@ -196,19 +266,30 @@ class SwitchCaseTermVisitor(c_ast.NodeVisitor):
self.hit = True
class SwitchCaseTranVisitor(c_ast.NodeVisitor):
- def __init__(self, asm_node, states, fallthrough_states, is_conditional):
+ def __init__(self, asm_node, path, states, fallthrough_states, is_conditional, invariants):
super().__init__()
self.states = states
+ self.path = path
+ self.invariants = invariants
self.fallthrough_states = fallthrough_states
self.is_conditional = is_conditional
self._asm_node = asm_node
self.tran_table = []
def visit_Case(self, node):
- state_from = node.children()[0][1].name
+ clev = CaseLabelExtractionVisitor()
+ clev.visit(node.children()[0][1])
+ state_from = clev.label
+ # highly inefficient but it is what it is
+ hit = False
+ #for n in self.path:
sctv = SwitchCaseTermVisitor(self._asm_node)
+ #node.show()
sctv.visit(node)
- if sctv.hit and (state_from in self.states):
+ hit = sctv.hit
+
+ #print(state_from, "->", self._asm_node.rvalue.name, "? hit=", hit, "isstate=", (state_from in self.states), "invar=", self.invariants)
+ if (hit or (state_from in self.invariants)) and (state_from in self.states):
#if conditional, state remains in state sometimes
if self.is_conditional:
self.tran_table.append((state_from,state_from))
@@ -218,8 +299,9 @@ class SwitchCaseTranVisitor(c_ast.NodeVisitor):
self.tran_table.append((ft, self._asm_node.rvalue.name))
class SwitchCasePropertyVisitor(c_ast.NodeVisitor):
- def __init__(self, state_asmts):
+ def __init__(self, state_asmts, func_blacklist = None):
super().__init__()
+ self._func_bl = func_blacklist
self._sas = state_asmts
self.properties = []
@@ -227,7 +309,8 @@ class SwitchCasePropertyVisitor(c_ast.NodeVisitor):
name = node.name.name
args = expr_list_to_str(node.args) if node.args is not None else ""
fcall = f"{name}({args})"
- self.properties.append(fcall)
+ if not(name in self._func_bl):
+ self.properties.append(fcall)
def visit_Assignment(self, node):
if not(node in self._sas):
@@ -248,8 +331,9 @@ class SwitchCasePropertyVisitor(c_ast.NodeVisitor):
class SwitchCaseCodePropertyVisitor(c_ast.NodeVisitor):
- def __init__(self, case, state_asmts):
+ def __init__(self, case, state_asmts, func_blacklist):
super().__init__()
+ self._func_bl = func_blacklist
self._case = case
self._sas = state_asmts
self.properties = []
@@ -257,8 +341,10 @@ class SwitchCaseCodePropertyVisitor(c_ast.NodeVisitor):
def visit_Case(self, node):
label = node.children()[0][1]
block = node
- if label.name == self._case:
- scpv = SwitchCasePropertyVisitor(self._sas)
+ clev = CaseLabelExtractionVisitor()
+ clev.visit(label)
+ if clev.label == self._case:
+ scpv = SwitchCasePropertyVisitor(self._sas, self._func_bl)
scpv.visit(block)
self.properties += scpv.properties