diff --git a/astvisitors.py b/astvisitors.py index a8621b0..53a9ed4 100644 --- a/astvisitors.py +++ b/astvisitors.py @@ -41,6 +41,14 @@ def path_filter(p, t): elems.append(n) return elems +def path_filter_multi(p, ts): + elems = [] + for n in p: + for t in ts: + if isinstance(n, t): + elems.append(n) + return elems + class ExprListSerializerVisitor(c_ast.NodeVisitor): def __init__(self): self.serial = [] @@ -229,16 +237,27 @@ class StateAssignmentVisitor(NodeVisitorFuncCallForward): asm_if = path_select_last('If', path) asm_case = path_select_last('Case', path) asm_ifs = path_filter(path, c_ast.If) + asm_condchain = path_filter_multi(path, [c_ast.If, c_ast.Case]) asm_cases = path_filter(path, c_ast.Case) subpath_case = path_slice_back(case_node, path) is_exhaustive_conditional = True if asm_if is None else (asm_if.iftrue is not None) and (asm_if.iffalse is not None) is_conditional = path_contains(subpath_case, c_ast.If) and not(is_exhaustive_conditional) type_condition_antivalent = False + + # for ccn in asm_condchain: + # condition = None + # if isinstance(ccn, c_ast.If): + # condition = ccn.cond + # else: + # condition = ccn.expr + # cg = c_generator.CGenerator() + # expr = cg.visit(condition) + # print(f" - {expr}") + #if asm_case is not None: for case in asm_cases: elsv = ExprListSerializerVisitor() elsv.visit(case.expr) - print("big case", elsv.serial[0]) if elsv.serial[0] in self.ccb: type_condition_antivalent = True break @@ -247,7 +266,6 @@ class StateAssignmentVisitor(NodeVisitorFuncCallForward): for if_node in asm_ifs: cg = c_generator.CGenerator() if_expr = cg.visit(if_node.cond) - print("big if", if_expr) for incl_cond in self.ccw: match = re.search(f"(\!=|==)[\(\)\w\d\s]+{incl_cond}", if_expr) if match is not None: @@ -267,12 +285,23 @@ class StateAssignmentVisitor(NodeVisitorFuncCallForward): if type_condition_antivalent: break - - if not(isinstance(n.rvalue, c_ast.Constant) or isinstance(n.rvalue, c_ast.BinaryOp) or isinstance(n.rvalue, c_ast.TernaryOp)): - rval_str = n.rvalue.name - print(f">>>> {rval_str} == {self.state} antivalent={type_condition_antivalent}") - if (rval_str == self.state) and not(type_condition_antivalent): - self.assignments.append((n,path,fallthrough_case_names,is_conditional,invariants)) + if isinstance(n.rvalue, c_ast.TernaryOp): + # a ternary op -> we have to dissect it + cg = c_generator.CGenerator() + expr = cg.visit(n.rvalue) + elsv = ExprListSerializerVisitor() + elsv.visit(n.rvalue) + if (self.state in elsv.serial) and not(type_condition_antivalent): + self.assignments.append((n,path,fallthrough_case_names,is_conditional,invariants)) + print(f">>>> {self.state} ∈ {expr} antivalent={type_condition_antivalent}") + #n.rvalue.show() + #print(expr) + else: + if not(isinstance(n.rvalue, c_ast.Constant) or isinstance(n.rvalue, c_ast.BinaryOp)): + rval_str = n.rvalue.name + print(f">>>> {rval_str} == {self.state} antivalent={type_condition_antivalent}") + if (rval_str == self.state) and not(type_condition_antivalent): + self.assignments.append((n,path,fallthrough_case_names,is_conditional,invariants)) class EnumDefVisitor(c_ast.NodeVisitor): def __init__(self, name): @@ -340,9 +369,19 @@ class SwitchCaseTranVisitor(c_ast.NodeVisitor): if self.is_conditional: self.tran_table.append((state_from,state_from)) # process state assignment - self.tran_table.append((state_from, self._asm_node.rvalue.name)) - for ft in self.fallthrough_states: - self.tran_table.append((ft, self._asm_node.rvalue.name)) + if isinstance(self._asm_node.rvalue, c_ast.TernaryOp): + elsv = ExprListSerializerVisitor() + elsv.visit(self._asm_node.rvalue) + ids = elsv.serial + for i in ids: + if i in self.states: + self.tran_table.append((state_from, i)) + for ft in self.fallthrough_states: + self.tran_table.append((ft, i)) + else: + self.tran_table.append((state_from, self._asm_node.rvalue.name)) + for ft in self.fallthrough_states: + self.tran_table.append((ft, self._asm_node.rvalue.name)) class SwitchCasePropertyVisitor(c_ast.NodeVisitor): def __init__(self, state_asmts, func_blacklist = None):