From 2a4c4b7660681392831c00cf17caf58f474c96d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dominic=20H=C3=B6glinger?= Date: Fri, 9 Dec 2022 17:44:06 +0100 Subject: [PATCH] now working on extracting a full mealy FSM from source code --- analyze.py | 280 +++++++-------- astvisitors.py | 922 ++++++++++++++++++++++++++++-------------------- modelbuilder.py | 39 +- 3 files changed, 717 insertions(+), 524 deletions(-) diff --git a/analyze.py b/analyze.py index 4d6be99..c95af24 100644 --- a/analyze.py +++ b/analyze.py @@ -1,15 +1,24 @@ from __future__ import print_function import argparse import os,sys,io -from pycparser import parse_file, c_ast, CParser +from pycparser import parse_file, c_ast, CParser, c_generator from pcpp import Preprocessor import graphviz as gv import html +from dataclasses import dataclass from modelbuilder import * from utils import * from astvisitors import * +def parse_statevar(statevar): + r_struct = r'(\w+)(\.|->)(\w+)' + match = re.match(r_struct, statevar) + if match: + name,typ,field = match.groups() + return c_ast.StructRef(c_ast.ID(name), typ, c_ast.ID(field)) + return c_ast.ID(statevar) + if __name__ == "__main__": argparser = argparse.ArgumentParser('Create a Kripke Structure Model from C Code') argparser.add_argument('filename', @@ -22,6 +31,7 @@ if __name__ == "__main__": argparser.add_argument('-c', '--conditional', help='only count state assignment if this conditional applies on the path') argparser.add_argument('--func', help='process function') argparser.add_argument('--enum', help='state enum') + argparser.add_argument('--statevar', help='state variable') argparser.add_argument('--initial', help='initial state') argparser.add_argument('--ltlfile', help='file containing LTL formulae') argparser.add_argument('-o', '--output', dest='output', help='output NuSMV file') @@ -36,29 +46,37 @@ if __name__ == "__main__": p = Preprocessor() p.add_path('/usr/lib/gcc/x86_64-linux-gnu/12/include/') for inc in args.includedirs: - print(f"Include-Dir: {inc}") + #print(f"Include-Dir: {inc}") p.add_path(inc) for define in args.defines: name,value = define.split('=') - print(f"Define: {name}={value}") + #print(f"Define: {name}={value}") p.define(f"{name} {value}") p.parse(source) oh = io.StringIO() p.write(oh) - prep_source = oh.getvalue()#unicode(oh.getvalue(), errors='ignore') + prep_source = oh.getvalue() if args.preprocess is not None: with open(args.preprocess, "wt") as f: n = f.write(prep_source) - + + cg = CGenerator() + parser = CParser() ast = parser.parse(prep_source) - #ast = parse_file(args.filename, use_cpp=False) - #ast.show() - + ass = parse_statevar(args.statevar) + tg = TypedefGatherer() + tg.visit(ast) + typedefs = tg.typedefs + #print("Typedefs", typedefs.keys()) + tast = TypedIdTransformer(typedefs).transform(ast) + #tast.show() + ast = tast + initial_state = args.initial assign_table = [] state_enums = [] @@ -66,6 +84,15 @@ if __name__ == "__main__": enum_table = {} state_asmts = [] fsm_funcs = [] + tran_table = [] + properties = [] + + state_id = 'm_tState' + + #TMR + #input_ids = ['m_bIn', 'm_ulActTime'] + input_ids = ['cbEnable', 'cbCondition', 'm_ucEnableUnconditioned'] + #conditional_ids = [*input_ids, 'm_ucTimerType'] fdv = FuncDefVisitor() fdv.visit(ast) @@ -75,24 +102,24 @@ if __name__ == "__main__": raise Exception(f"Function name '{args.func}' not found!") else: proc_func = func_table[args.func] + + fast = AstFuncCallInlinerTransformer(func_table).transform(ast) + ast = fast etv = EnumTypedefVisitor() etv.visit(ast) enum_table = etv.enums states = [] - def discover_cals(func, visited=None): - if visited is None: - visited = [] + def discover_cals(func): funcs = [] child_funcs = [] fcv = FuncCallVisitor() fcv.visit(func) - visited.append(proc_func) for fc in fcv.func_calls: if fc in func_table: funcs.append(func_table[fc]) - child_funcs += discover_cals(func_table[fc], visited) + child_funcs += discover_cals(func_table[fc]) return funcs + child_funcs # discover FSM functions @@ -139,136 +166,111 @@ if __name__ == "__main__": for ename in ev.enum_names: print(" - ",ename) states.append(ename) - #for f in fsm_funcs: - sav = StateAssignmentVisitor(ast, ename, cond_blacklist, cond_whitelist) - #sav.visit(f) - sav.visit(proc_func) - state_asmts += sav.assignments else: print(f"Initial State Enum '{args.enum}' not found") + print("") + + print("Extracting State Transitions") + sav = StateAssignmentVisitor(func_table, ass) + sav.visit(proc_func) + for assignment in sav.assignments: + state_to = assignment.state + cpa = ConditionalPathAnalyzer() + #print("Path", path_to_str(assignment.path)) + cpa.analyze(reversed(assignment.path)) + cond_chain = [] + conds_state = [] + state_from_expr = None + disregard_transition = False + for c in cpa.condition_chain: + # filter by conditional enum + # this is far from correct + cond_blacklist_visitor = ContainsOneOfIdVisitor(cond_blacklist) + cond_blacklist_visitor.visit(c) + if cond_blacklist_visitor.hit: + disregard_transition = True + break + conditional_visitor = ContainsOneOfIdVisitor(input_ids) + conditional_visitor.visit(c) + if conditional_visitor.hit: + cond_chain.append(c) + state_condition_visitor = ContainsOneOfIdVisitor([state_id]) + state_condition_visitor.visit(c) + if state_condition_visitor.hit: + conds_state.append(c) + + if not(disregard_transition): + if len(conds_state) != 1: + cond_exprs = [cg.visit(x) for x in cpa.condition_chain] + print("OOPS") + print(cond_exprs) + raise Exception("No or too many state conditions found") + + # find out from which state the assignment transitions + current_state_visitor = ContainsOneOfIdVisitor(states) + current_state_visitor.visit(conds_state[0]) + if not(current_state_visitor.hit): + raise Exception("State assignment does not assign state enum") + state_from = current_state_visitor.name + condition = None + if len(cond_chain) == 1: + condition = cond_chain[0] + elif len(cond_chain) > 1: + condition = cond_chain[0] + for expr in cond_chain[1:len(cond_chain)]: + condition = c_ast.BinaryOp('&&', expr, condition) + + cond_exprs = [cg.visit(x) for x in cond_chain] + cond_expr = cg.visit(condition) + + tf = NuSmvConditionTransformer() + cond_mod = tf.transform(condition) + mod_expr = cg.visit(cond_mod) - paths = [] - for asm in state_asmts: - paths.append(asm[1]) - - #common = find_common_ancestor(paths) - - tran_table = [] - - for sa in state_asmts: - for f in fsm_funcs: - sctv = SwitchCaseTranVisitor(sa[0], sa[1], states, sa[2], sa[3], sa[4]) - sctv.visit(f) - tran_table += sctv.tran_table - - comp_tt = {} - print("Transitions") - for t in tran_table: - print(f"{t[0]}->{t[1]}") - if t[0] in comp_tt: - if t[1] not in comp_tt[t[0]]: - comp_tt[t[0]].append(t[1]) - else: - comp_tt[t[0]] = [t[1]] + print(f" - {state_from} -> {state_to} on {mod_expr}") + tran_table.append(StateTransition(state_from, state_to, condition)) print("") + # Extract properties + av = AssignmentVisitor(func_table) + av.visit(proc_func) + assigments = [] + print("Assignments") + for a in av.assignments: + if a.node.op == '=': + if isinstance(a.node.lvalue, c_ast.StructRef): + if not(check_structref_equivalence(a.node.lvalue, ass)): + assigments.append(a) + + for a in assigments: + cpa = ConditionalPathAnalyzer() + #print("Path", path_to_str(assignment.path)) + cpa.analyze(reversed(a.path)) + cond_chain = [] + conds_state = [] + state_from_expr = None + disregard_transition = False + for c in cpa.condition_chain: + # filter by conditional enum + # this is far from correct + cond_blacklist_visitor = ContainsOneOfIdVisitor(cond_blacklist) + cond_blacklist_visitor.visit(c) + if cond_blacklist_visitor.hit: + disregard_transition = True + break + conditional_visitor = ContainsOneOfIdVisitor(input_ids) + conditional_visitor.visit(c) + if conditional_visitor.hit: + cond_chain.append(c) + state_condition_visitor = ContainsOneOfIdVisitor([state_id]) + state_condition_visitor.visit(c) + if state_condition_visitor.hit: + conds_state.append(c) - pure_sa = [x[0] for x in state_asmts] - #todo recomment once fixed - #states = comp_tt.keys() - print("States: ", ",".join(states)) - print("") - - prop_func_blacklist = [f.decl.name for f in fsm_funcs] - reachable_states = [initial_state] - props_by_state = {} - print("Compact Transition Table:") - for n,ms in comp_tt.items(): - sstr = ','.join(ms) - print(f"{n}->{{{sstr}}}") - for s in ms: - if not(s in reachable_states): - reachable_states.append(s) - # find properties - for f in fsm_funcs: - sccpv = SwitchCaseCodePropertyVisitor(n, states, prop_func_blacklist) - sccpv.visit(f) - if len(sccpv.properties) > 0: - props_by_state[n] = sccpv.properties - print("") + if not(disregard_transition): + cond_exprs = [cg.visit(x) for x in cond_chain] + state_cond_exprs = [cg.visit(x) for x in conds_state] + print(f" - {cg.visit(a.node)} {cond_exprs} {state_cond_exprs}") - states_by_property = {} - for state,props in props_by_state.items(): - for prop in props: - if prop in states_by_property: - states_by_property[prop].append(state) - else: - states_by_property[prop] = [state] - - properties = {} - property_alias = {} - for i,(prop,pstates) in enumerate(states_by_property.items()): - alias = base_10_to_alphabet(i + 1) - property_alias[prop] = alias - properties[alias] = (pstates, prop) - - print("Properties") - for prop,(pstates,full_prop) in properties.items(): - ss = ','.join(pstates) - print(f" - '{full_prop}' when {ss}") - print("") - - print("States shortform:") - states_prefix = os.path.commonprefix(states) - state_to_short = {} - for s in states: - state_to_short[s] = remove_prefix(s, states_prefix) - ltls = [] - if args.ltlfile is not None: - with open(args.ltlfile) as f: - ltls = [line.rstrip() for line in f] - - - mod = ModelBuilder(reachable_states, initial_state, comp_tt, properties, ltls=ltls) + mod = ModelBuilder(states, initial_state, tran_table, properties) nusmv = mod.generate() - - if args.output is not None: - with open(args.output, "wt") as f: - n = f.write(nusmv) - else: - print("-------------------") - print(nusmv) - - if args.dot is not None: - g = gv.Digraph('G') - # add states - for state in reachable_states: - state_short = state_to_short[state] - shape = 'oval' - if state == initial_state: - shape = 'doubleoctagon' - - if state in props_by_state: - pstr = ",".join([property_alias[x] for x in props_by_state[state]]) - g.node(state_short, label=f"{state_short}\n\n{{{pstr}}}", shape=shape) - else: - g.node(state_short, label=state_short, shape=shape) - - # add transitions - for n,ms in comp_tt.items(): - for m in ms: - g.edge(state_to_short[n], state_to_short[m]) - - # add property table - """ - tabattr = 'border="0px"' - table_rows = [] - for prop,alias in property_alias.items(): - alias = html.escape(alias) - prop = html.escape(prop) - table_rows.append(f"{alias}{prop}") - table_rows = ''.join(table_rows) - html = f"{table_rows}
" - print(html) - g.node("Properties", label=f"<{html}>",shape="none", rank="sink") - """ - g.render(filename=args.dot) + print(nusmv) diff --git a/astvisitors.py b/astvisitors.py index a6a75d3..bd7440e 100644 --- a/astvisitors.py +++ b/astvisitors.py @@ -1,133 +1,11 @@ import re from pycparser import c_ast, c_generator from itertools import pairwise +from dataclasses import dataclass def path_to_str(p): return "->".join([str(n.__class__.__name__) for n in p]) -def path_select_last(classname, p): - for n in reversed(p): - if n.__class__.__name__ == classname: - return n - return None - -def path_contains(p, t): - for n in p: - if isinstance(n, t): - return True - return False - -def path_select_parent(child, p): - parent = None - for n in p: - if n == child: - return parent - parent = n - return None - -def path_slice_back(node, p): - s = [] - for n in reversed(p): - if n == node: - break - s.append(n) - return [x for x in reversed(s)] - - -def path_filter(p, t): - elems = [] - for n in p: - if isinstance(n, t): - elems.append(n) - return elems - -def path_filter_multi(p, ts): - elems = [] - for n in p: - for t in ts: - if isinstance(n, t): - elems.append(n) - return elems - -class ExprListSerializerVisitor(c_ast.NodeVisitor): - def __init__(self): - self.serial = [] - - def visit_Constant(self, node): - expr = node.value - self.serial.append(expr) - - def visit_ID(self, node): - expr = node.name - self.serial.append(expr) - - #todo: expand - -class CaseLabelExtractionVisitor(c_ast.NodeVisitor): - def __init__(self): - self.label = None - - def visit_ID(self, node): - self.label = node.name - - -def expr_list_to_str(exprl): - elsv = ExprListSerializerVisitor() - elsv.visit(exprl) - return ','.join(elsv.serial) - -class NodeVisitorWithParent(object): - def __init__(self): - self.current_parent = None - - def visit(self, node): - """ Visit a node. - """ - method = 'visit_' + node.__class__.__name__ - visitor = getattr(self, method, self.generic_visit) - return visitor(node) - - def generic_visit(self, node): - """ Called if no explicit visitor function exists for a - node. Implements preorder visiting of the node. - """ - oldparent = self.current_parent - self.current_parent = node - for c in node.children(): - self.visit(c) - self.current_parent = oldparent - -class NodeVisitorFuncCallForward(object): - def __init__(self, ast): - self.current_parent = None - self.ast = ast - fdv = FuncDefVisitor() - fdv.visit(ast) - self.func_table = fdv.func_table - - def visit_FuncCall(self, node): - print("Visiting FuncCall") - print(node.show()) - print('---- parent ----') - print(self.current_parent.show()) - - def visit(self, node): - """ Visit a node. - """ - method = 'visit_' + node.__class__.__name__ - visitor = getattr(self, method, self.generic_visit) - return visitor(node) - - def generic_visit(self, node): - """ Called if no explicit visitor function exists for a - node. Implements preorder visiting of the node. - """ - oldparent = self.current_parent - self.current_parent = node - for c in node.children(): - self.visit(c) - self.current_parent = oldparent - class FuncDefVisitor(c_ast.NodeVisitor): def __init__(self): self.func_table = {} @@ -142,167 +20,6 @@ class FuncCallVisitor(c_ast.NodeVisitor): def visit_FuncCall(self, node): self.func_calls.append(node.children()[0][1].name) -class StateAssignmentVisitor(NodeVisitorFuncCallForward): - def __init__(self, ast, state, config_cond_blacklist=None, config_cond_whitelist=None): - super().__init__(ast) - self._method_cache = {} - self.state = state - self.ccb = config_cond_blacklist - self.ccw = config_cond_whitelist - self.assignments = [] - - def visit(self, node, path = None, invariants = None): - """ Visit a node. - """ - if path is None: - path = [] - if invariants is None: - invariants = [] - if self._method_cache is None: - self._method_cache = {} - - visitor = self._method_cache.get(node.__class__.__name__, None) - if visitor is None: - method = 'visit_' + node.__class__.__name__ - visitor = getattr(self, method, self.generic_visit) - self._method_cache[node.__class__.__name__] = visitor - - return visitor(node, path, invariants) - - def generic_visit(self, node, path, invariants): - """ Called if no explicit visitor function exists for a - node. Implements preorder visiting of the node. - """ - path = path.copy() - path.append(node) - for c in node: - self.visit(c, path, invariants) - - def visit_FuncCall(self, node, path, invariants): - fcall = node.name.name - #print("CAL path", path_to_str(path)) - cases = path_filter(path, c_ast.Case) - #print(cases) - #new_invariants = [x.expr.name for x in cases] - new_invariants = [] - for x in cases: - clev = CaseLabelExtractionVisitor() - clev.visit(x.expr) - new_invariants.append(clev.label) - - invariants = invariants.copy() - for ni in new_invariants: - if not(ni in invariants): - invariants.append(ni) - #print("invariants:", invariants) - #print(f"Visiting FuncCall {fcall}") - if fcall in self.func_table: - #print("->deferring!") - self.visit(self.func_table[fcall], path, invariants) - #print('---- path ----') - - def visit_Assignment(self, n, path, invariants): - # fallthrough detection - case_node = path_select_last('Case', path) - fallthrough_case_names = [] - if case_node is not None: - case_parent = path_select_parent(case_node, path) - siblings = [x for x in case_parent.block_items if not isinstance(x, c_ast.Default)] - #for s in siblings: - # s.show() - sibling_names = [(x.expr.expr.name if isinstance(x.expr, c_ast.Cast) else x.expr.name) for x in siblings] - sibling_empty = [(len(x.stmts) == 0) for x in siblings] - - in_fallthrough = False - slice_start = None - slice_end = None - for i,sibling in enumerate(zip(reversed(sibling_names),reversed(sibling_empty), reversed(siblings))): - if sibling[0] == self.state: - slice_start = i + 1 - if (slice_start is not None) and sibling[1] and not(in_fallthrough) and (slice_start == i): - in_fallthrough = True - if in_fallthrough: - if sibling[1]: - slice_end = i - else: - in_fallthrough = False - if (slice_start is not None) and (slice_end is not None): - slice_start_temp = slice_start - slice_start = len(siblings) - slice_end - slice_end = len(siblings) - slice_start_temp - fallthrough_case_names = sibling_names[slice_start-1:slice_end] - rval_str = '' - - # conditional assignment detection - asm_if = path_select_last('If', path) - asm_case = path_select_last('Case', path) - asm_ifs = path_filter(path, c_ast.If) - asm_condchain = path_filter_multi(path, [c_ast.If, c_ast.Case]) - asm_cases = path_filter(path, c_ast.Case) - subpath_case = path_slice_back(case_node, path) - is_exhaustive_conditional = True if asm_if is None else (asm_if.iftrue is not None) and (asm_if.iffalse is not None) - is_conditional = path_contains(subpath_case, c_ast.If) and not(is_exhaustive_conditional) - type_condition_antivalent = False - - # for ccn in asm_condchain: - # condition = None - # if isinstance(ccn, c_ast.If): - # condition = ccn.cond - # else: - # condition = ccn.expr - # cg = c_generator.CGenerator() - # expr = cg.visit(condition) - # print(f" - {expr}") - - #if asm_case is not None: - for case in asm_cases: - elsv = ExprListSerializerVisitor() - elsv.visit(case.expr) - if elsv.serial[0] in self.ccb: - type_condition_antivalent = True - break - - if not(type_condition_antivalent): - for if_node in asm_ifs: - cg = c_generator.CGenerator() - if_expr = cg.visit(if_node.cond) - for incl_cond in self.ccw: - match = re.search(f"(\!=|==)[\(\)\w\d\s]+{incl_cond}", if_expr) - if match is not None: - op = match.groups()[0] - if op == '!=': - type_condition_antivalent = True - break - - for excl_cond in self.ccb: - match = re.search(f"(\!=|==)[\(\)\w\d\s]+{excl_cond}", if_expr) - if match is not None: - op = match.groups()[0] - if op == '==': - type_condition_antivalent = True - break - - if type_condition_antivalent: - break - - if isinstance(n.rvalue, c_ast.TernaryOp): - # a ternary op -> we have to dissect it - cg = c_generator.CGenerator() - expr = cg.visit(n.rvalue) - elsv = ExprListSerializerVisitor() - elsv.visit(n.rvalue) - if (self.state in elsv.serial) and not(type_condition_antivalent): - self.assignments.append((n,path,fallthrough_case_names,is_conditional,invariants)) - print(f">>>> {self.state} ∈ {expr} antivalent={type_condition_antivalent}") - #n.rvalue.show() - #print(expr) - else: - if not(isinstance(n.rvalue, c_ast.Constant) or isinstance(n.rvalue, c_ast.BinaryOp)): - rval_str = n.rvalue.name - print(f">>>> {rval_str} == {self.state} antivalent={type_condition_antivalent}") - if (rval_str == self.state) and not(type_condition_antivalent): - self.assignments.append((n,path,fallthrough_case_names,is_conditional,invariants)) - class EnumDefVisitor(c_ast.NodeVisitor): def __init__(self, name): super().__init__() @@ -330,112 +47,567 @@ class EnumVisitor(c_ast.NodeVisitor): def visit_Enumerator(self, node): self.enum_names.append(node.name) -class SwitchCaseTermVisitor(c_ast.NodeVisitor): - def __init__(self, asm_node): +class PathVisitor(object): + def __init__(self): + self._method_cache = {} + + def visit(self, node, path = None): + if path is None: + path = [] + + visitor = self._method_cache.get(node.__class__.__name__, None) + if visitor is None: + method = 'visit_' + node.__class__.__name__ + visitor = getattr(self, method, self.generic_visit) + self._method_cache[node.__class__.__name__] = visitor + + return visitor(node, path) + + def generic_visit(self, node, path): + path = path.copy() + path.append(node) + for c in node: + self.visit(c, path) + +class FuncCallDeferPathVisitor(PathVisitor): + def __init__(self, func_table): super().__init__() - self._asm_node = asm_node + self.func_table = func_table + + def visit_FuncCall(self, node, path): + fcall = node.name.name + if fcall in self.func_table: + self.visit(self.func_table[fcall], path) + +def check_node_equivalence(lhs, rhs): + return str(lhs) == str(rhs) + +def check_structref_equivalence(lhs:c_ast.StructRef, rhs:c_ast.StructRef): + if lhs.name.name != rhs.name.name: + return False + if lhs.field.name != rhs.field.name: + return False + return True + +@dataclass +class PathStateAssignment: + path:list[c_ast.Node] + node:c_ast.Assignment + state:str + +class StateAssignmentVisitor(FuncCallDeferPathVisitor): + def __init__(self, func_table, variable): + super().__init__(func_table) + self.variable = variable + self.assignments = [] + + def visit_Assignment(self, n, path): + if check_structref_equivalence(self.variable, n.lvalue): + if isinstance(n.rvalue, c_ast.ID): + self.assignments.append(PathStateAssignment(path, n, n.rvalue.name)) + elif isinstance(n.rvalue, c_ast.TernaryOp): + #print("TOP", n.rvalue) + pass + + +@dataclass +class PathAssignment: + path:list[c_ast.Node] + node:c_ast.Assignment + +class AssignmentVisitor(FuncCallDeferPathVisitor): + def __init__(self, func_table): + super().__init__(func_table) + self.assignments = [] + + def visit_Assignment(self, n, path): + self.assignments.append(PathAssignment(path, n)) + + +class PathAnalyzer(object): + def __init__(self): + self._method_cache = {} + + def analyze(self, path): + trace = [] + for node in path: + self._invoke_visitor(node, trace) + trace.append(node) + + def _invoke_visitor(self, node, trace): + visitor = self._method_cache.get(node.__class__.__name__, None) + if visitor is None: + method = 'visit_' + node.__class__.__name__ + visitor = getattr(self, method, None) + self._method_cache[node.__class__.__name__] = visitor + + if visitor is not None: + visitor(node, trace) + +@dataclass +class GeneralizedCase: + exprs:list[c_ast.Node] + stmts:list[c_ast.Node] + +class ConditionalPathAnalyzer(PathAnalyzer): + def __init__(self): + super().__init__() + self.condition_chain = [] + + def visit_If(self, node, trace): + p = trace[-1] + #print("If", node) + if node.iftrue == p: + #print("->iftrue") + self.condition_chain.append(node.cond) + elif node.iffalse == p: + #print("->iffalse") + neg = c_ast.UnaryOp('!', node.cond) + self.condition_chain.append(neg) + else: + pass + + def visit_Switch(self, node, trace): + p_case = trace[-2] + expr_fallthrough = [] + gencases = [] + cond = node.cond + block_items = node.stmt.block_items + parent_case_idx = None + #print("Switch", cond) + #print("P:", p_case) + #print("CI", block_items) + # lump together fallthrough cases + for case in block_items: + if p_case == case: + parent_case_idx = len(gencases) + if not isinstance(case, c_ast.Default): + expr_fallthrough.append(case.expr) + if len(case.stmts) != 0 and isinstance(case.stmts[-1], c_ast.Break): + gencases.append(GeneralizedCase(expr_fallthrough, case.stmts)) + expr_fallthrough = [] + #print("Generalized P-Case", gencases[parent_case_idx].exprs) + # does not account for default, which needs the entire set of cases + # checked for unequals and and-ed + eqops = [c_ast.BinaryOp('==', cond, expr) for expr in gencases[parent_case_idx].exprs] + if len(eqops) == 1: + self.condition_chain.append(eqops[0]) + elif len(eqops) > 1: + orop = eqops[0] + for expr in eqops[1:len(eqops)]: + orop = c_ast.BinaryOp('||', expr, orop) + self.condition_chain.append(orop) + else: + raise Exception("Unexpected number of expressions") + +class ContainsOneOfIdVisitor(c_ast.NodeVisitor): + def __init__(self, ids): + self.ids = ids self.hit = False + self.name = None - def visit_Assignment(self, node): - if node == self._asm_node: + def visit_ID(self, node): + if node.name in self.ids: self.hit = True + self.name = node.name -class SwitchCaseTranVisitor(c_ast.NodeVisitor): - def __init__(self, asm_node, path, states, fallthrough_states, is_conditional, invariants): + def visit_TypedID(self, node): + return self.visit_ID(node) + +@dataclass +class StateTransition: + state_from:str + state_to:str + condition:c_ast.Node + +class TypedefGatherer(c_ast.NodeVisitor): + def __init__(self): + self.typedefs = {} + + def visit_Typedef(self, node): + self.typedefs[node.name] = node + + +class AstTransformer(object): + def __init__(self): + self._method_cache = {} + + def transform(self, node): + visitor = self._method_cache.get(node.__class__.__name__, None) + if visitor is None: + method = 'transform_' + node.__class__.__name__ + visitor = getattr(self, method, self._transform_generic) + self._method_cache[node.__class__.__name__] = visitor + + return visitor(node) + + def _transform_Node(self, node): + new_c = [] + for c_name in node.__slots__[0:-1]: + c = getattr(node, c_name) + new_c.append(self.transform(c)) + node_constructor = node.__class__ + return node_constructor(*new_c) + + def _transform_generic(self, node): + if isinstance(node, c_ast.Node): + return self._transform_Node(node) + elif isinstance(node, list): + return [self.transform(x) for x in node] + else: + return node + +class AstPathTransformer(object): + def __init__(self): + self._method_cache = {} + + def transform(self, node, path = None): + if path is None: + path = [node] + else: + path = path.copy() + path.append(node) + + visitor = self._method_cache.get(node.__class__.__name__, None) + if visitor is None: + method = 'transform_' + node.__class__.__name__ + visitor = getattr(self, method, self._transform_generic) + self._method_cache[node.__class__.__name__] = visitor + + return visitor(node, path) + + def _transform_Node(self, node, path): + new_c = [] + for c_name in node.__slots__[0:-1]: + c = getattr(node, c_name) + new_c.append(self.transform(c, path)) + node_constructor = node.__class__ + return node_constructor(*new_c) + + def _transform_generic(self, node, path): + if isinstance(node, c_ast.Node): + return self._transform_Node(node, path) + elif isinstance(node, list): + return [self.transform(x, path) for x in node] + else: + return node + + + +class TypedID(c_ast.ID): + __slots__ = ('name', 'type', 'coord') + def __init__(self, name, type=None, coord=None): + super().__init__(name, coord) + self.type = type + + def children(self): + nodelist = [] + return tuple(nodelist) + + def __iter__(self): + return + yield + + attr_names = ('name', 'type' ) + +class CGenerator(c_generator.CGenerator): + def __init__(self): super().__init__() - self.states = states - self.path = path - self.invariants = invariants - self.fallthrough_states = fallthrough_states - self.is_conditional = is_conditional - self._asm_node = asm_node - self.tran_table = [] + + def visit_TypedID(self, n): + return n.name - def visit_Case(self, node): - clev = CaseLabelExtractionVisitor() - clev.visit(node.children()[0][1]) - state_from = clev.label - # highly inefficient but it is what it is - hit = False - #for n in self.path: - sctv = SwitchCaseTermVisitor(self._asm_node) - #node.show() - sctv.visit(node) - hit = sctv.hit - #print(state_from, "->", self._asm_node.rvalue.name, "? hit=", hit, "isstate=", (state_from in self.states), "invar=", self.invariants) - if (hit or (state_from in self.invariants)) and (state_from in self.states): - #if conditional, state remains in state sometimes - if self.is_conditional: - self.tran_table.append((state_from,state_from)) - # process state assignment - if isinstance(self._asm_node.rvalue, c_ast.TernaryOp): - elsv = ExprListSerializerVisitor() - elsv.visit(self._asm_node.rvalue) - ids = elsv.serial - for i in ids: - if i in self.states: - self.tran_table.append((state_from, i)) - for ft in self.fallthrough_states: - self.tran_table.append((ft, i)) - else: - self.tran_table.append((state_from, self._asm_node.rvalue.name)) - for ft in self.fallthrough_states: - self.tran_table.append((ft, self._asm_node.rvalue.name)) +class TypeDeclFinder(c_ast.NodeVisitor): + def __init__(self, declname): + self.declname = declname + self.decl = None -class SwitchCasePropertyVisitor(c_ast.NodeVisitor): - def __init__(self, states, func_blacklist = None): + def visit_TypeDecl(self, node): + if node.declname == self.declname: + self.decl = node + +class DeclFinder(c_ast.NodeVisitor): + def __init__(self, declname): + self.declname = declname + self.decl = None + + def visit_Decl(self, node): + if node.name == self.declname: + self.decl = node + +class IdentifierTypeFinder(c_ast.NodeVisitor): + def __init__(self, declname): + self.type = None + + def visit_Decl(self, node): + self.decl = node + + +class TypeFinder(object): + def __init__(self, t): + self.type = t + self.result = [] + + def visit(self, node): + if node.__class__.__name__ == self.type.__name__: + self.result.append(node) + + for c in node: + self.visit(c) + + +class TypedIdTransformer(AstPathTransformer): + def __init__(self, type_dict): super().__init__() - self._func_bl = func_blacklist - self.states = states - self.properties = [] + self.type_dict = type_dict - def visit_FuncCall(self, node): - name = node.name.name - args = expr_list_to_str(node.args) if node.args is not None else "" - fcall = f"{name}({args})" - if not(name in self._func_bl): - self.properties.append(fcall) - - def visit_Assignment(self, node): - #if not(node in self._sas): - cg = c_generator.CGenerator() - rvalue = cg.visit(node.rvalue) - print("NODE RVAL", rvalue) - if not(rvalue in self.states): - lvalue = None - rvalue = None - if isinstance(node.lvalue, c_ast.StructRef): - lvalue = f"{node.lvalue.children()[0][1].name}->{node.lvalue.children()[1][1].name}"; - else: - lvalue = node.lvalue.name + def transform_ID(self, node, path): + #print("node[", node.name, "]") + #print("pstr[", node.name, "]", path_to_str(path)) + try: + funcdef = next(filter(lambda x:isinstance(x, c_ast.FuncDef), path)) + except: + funcdef = None + if funcdef is not None: + id_type = None + param_decl = funcdef.decl.type.args.params + #print("Param Decl", param_decl) + for decl in param_decl: + declfinder = TypeDeclFinder(node.name) + declfinder.visit(decl) + if declfinder.decl is not None: + id_type = self.type_dict.get(declfinder.decl.type.names[0], declfinder.decl.type.names[0]) + break + if id_type is not None: + #print("pid[", node.name, "]", id_type) + return TypedID(node.name, id_type) - #if isinstance(node.rvalue, c_ast.Constant): - # rvalue = f"{node.rvalue.type }({node.rvalue.value})" - #else: - # rvalue = node.rvalue.name - cg = c_generator.CGenerator() - rvalue = cg.visit(node.rvalue) + return node + + def transform_StructRef(self, node, path): + if isinstance(node.name, c_ast.ID): + funcdef = next(filter(lambda x:isinstance(x, c_ast.FuncDef), path)) + struct_type = None + param_decl = funcdef.decl.type.args.params + for decl in param_decl: + declfinder = TypeDeclFinder(node.name.name) + declfinder.visit(decl) + if declfinder.decl is not None: + struct_type = self.type_dict.get(declfinder.decl.type.names[0], None) + break + name_node = node.name + field_node = node.field + if struct_type is not None: + name_node = TypedID(node.name.name, struct_type) + # find member type + declfinder = DeclFinder(field_node.name) + declfinder.visit(struct_type) + tf = TypeFinder(c_ast.IdentifierType) + tf.visit(declfinder.decl.type.type) + idtype_str = " ".join(tf.result[0].names) + #print("idtype", idtype_str) + idtype = self.type_dict.get(idtype_str, idtype_str) + field_node = TypedID(node.field.name, idtype) + + return c_ast.StructRef(name_node, node.type, field_node, node.coord) + + name_node = self.transform(node.name, path) + field_node = self.transform(node.field, path) + struct_type = name_node.field.type + declfinder = DeclFinder(field_node.name) + declfinder.visit(struct_type) + tf = TypeFinder(c_ast.IdentifierType) + tf.visit(declfinder.decl.type.type) + idtype_str = " ".join(tf.result[0].names) + idtype = self.type_dict.get(idtype_str, idtype_str) + field_node = TypedID(node.field.name, idtype) + + return c_ast.StructRef(name_node, node.type, field_node, node.coord) - prop = f"{lvalue}<={rvalue}" - self.properties.append(prop) +class AstFuncCallInlinerTransformer(object): + def __init__(self, func_table): + self._method_cache = {} + self.func_table = func_table + + def transform(self, node, idtable = None): + if idtable is None: + idtable = {} + + visitor = self._method_cache.get(node.__class__.__name__, None) + if visitor is None: + method = 'transform_' + node.__class__.__name__ + visitor = getattr(self, method, self._transform_generic) + self._method_cache[node.__class__.__name__] = visitor + + return visitor(node, idtable) + + def _transform_Node(self, node, idtable): + new_c = [] + for c_name in node.__slots__[0:-1]: + c = getattr(node, c_name) + new_c.append(self.transform(c, idtable)) + node_constructor = node.__class__ + return node_constructor(*new_c) + + def _transform_generic(self, node, idtable): + if isinstance(node, c_ast.Node): + return self._transform_Node(node, idtable) + elif isinstance(node, list): + return [self.transform(x, idtable) for x in node] + else: + return node + + def transform_FuncCall(self, node, idtable): + fcall = node.name.name + if fcall in self.func_table: + funcdef_args = self.func_table[fcall].decl.type.args.params + fcall_args = node.args.exprs + if len(fcall_args) != len(funcdef_args): + raise Exception("Func Call does not match argument number") + # update ID table + idtable = idtable.copy() + for callarg,defarg in zip(fcall_args, funcdef_args): + tdec_finder = TypeFinder(c_ast.TypeDecl) + tdec_finder.visit(defarg) + if len(tdec_finder.result) == 1: + tdec = tdec_finder.result[0] + paramname = tdec.declname + idtable[paramname] = callarg + #print("idtable", idtable) + tfnode = self.transform(self.func_table[fcall], idtable) + return tfnode + return node + + def transform_ID(self, node, idtable): + if node.name in idtable: + return idtable[node.name] + return node + + def transform_TypedID(self, node, idtable): + return self.transform_ID(node, idtable) -class SwitchCaseCodePropertyVisitor(c_ast.NodeVisitor): - def __init__(self, case, states, func_blacklist): +def find_type_of_branch(node): + if isinstance(node, c_ast.StructRef): + return find_type_of_branch(node.field) + elif isinstance(node, TypedID): + return node.type + else: + pass + return None + +class NuSmvConditionTransformer(AstTransformer): + def __init__(self): super().__init__() - self._func_bl = func_blacklist - self._case = case - self.states = states - self.properties = [] - def visit_Case(self, node): - label = node.children()[0][1] - block = node - clev = CaseLabelExtractionVisitor() - clev.visit(label) - if clev.label == self._case: - scpv = SwitchCasePropertyVisitor(self.states, self._func_bl) - scpv.visit(block) - self.properties += scpv.properties + def transform_BinaryOp(self, node): + op = node.op + match op: + case '&&': + op = '&' + case '||': + op = '|' + case '==': + op = '=' + case _: + op = op + + lhs = node.left + rhs = node.right + lhs_type = find_type_of_branch(lhs) + rhs_type = find_type_of_branch(rhs) + + #print("l type:", lhs_type) + #print("r type:", rhs_type) + #print("l node:", lhs) + #print("r node:", rhs) + lhs_is_bool = lhs_type == '_Bool' + rhs_is_bool = rhs_type == '_Bool' + lhs_is_constant = isinstance(lhs, c_ast.Constant) + rhs_is_constant = isinstance(rhs, c_ast.Constant) + if lhs_is_bool and rhs_is_constant: + rhs = c_ast.ID("FALSE" if rhs.value == '0' else "TRUE") + elif rhs_is_bool and lhs_is_constant: + lhs = c_ast.ID("FALSE" if lhs.value == '0' else "TRUE") + + lhs = self.transform(lhs) + rhs = self.transform(rhs) + + return c_ast.BinaryOp(op, lhs, rhs) + + def transform_ID(self, node): + return c_ast.ID(node.name) + + def transform_TypedID(self, node): + return c_ast.ID(node.name) + + def transform_StructRef(self, node): + def srtf(node): + if isinstance(node, c_ast.StructRef): + return f"{srtf(node.name)}_{srtf(node.field)}" + elif isinstance(node, TypedID) or isinstance(node, c_ast.ID): + return node.name + return node + + full = srtf(node) + return c_ast.ID(full) + + def transform_Constant(self, node): + value = re.sub(r'([a-zA-Z]+)$', '', node.value) + return c_ast.Constant(node.type, value) + + def transform_UnaryOp(self, node): + return c_ast.UnaryOp(node.op, self.transform(node.expr)) + +class NuSmvVariableExtractor(c_ast.NodeVisitor): + SHORT_TYPE = '−32767..32767' + USHORT_TYPE = '0..65535' + INT_TYPE = '−32767..32767' + UINT_TYPE = '0..65535' + TYPE_LUT = { + '_Bool':'boolean', + 'char':'0..255', + 'signed char':'-127..127', + 'unsigned char':'0..255', + 'short':SHORT_TYPE, + 'short int':SHORT_TYPE, + 'signed short':SHORT_TYPE, + 'signed short int':SHORT_TYPE, + 'unsigned short':USHORT_TYPE, + 'unsigned short int':USHORT_TYPE, + 'int':INT_TYPE, + 'signed':INT_TYPE, + 'signed int':INT_TYPE, + 'unsigned':UINT_TYPE, + 'unsigned int':UINT_TYPE + } + def __init__(self): + self.variables = {} + + def visit_TypedID(self, node): + smvtype = self.TYPE_LUT.get(node.type, None) + if smvtype is None: + raise Exception(f"Type '{node.type}' is not supported") + self.variables[node.name] = smvtype + + def visit_StructRef(self, node): + def findfield(node): + if isinstance(node, c_ast.StructRef): + return findfield(node.field) + return node + + def srtf(node): + if isinstance(node, c_ast.StructRef): + return f"{srtf(node.name)}_{srtf(node.field)}" + elif isinstance(node, TypedID) or isinstance(node, c_ast.ID): + return node.name + return node + field = findfield(node) + full = srtf(node) + smvtype = self.TYPE_LUT.get(field.type, None) + if smvtype is None: + raise Exception(f"Type '{field.type}' is not supported") + self.variables[full] = smvtype + diff --git a/modelbuilder.py b/modelbuilder.py index 13bec19..91bcfd0 100644 --- a/modelbuilder.py +++ b/modelbuilder.py @@ -1,3 +1,4 @@ +from astvisitors import * MODEL = """ ------------------------------------------------------------------------ @@ -5,18 +6,22 @@ MODEL = """ ------------------------------------------------------------------------ MODULE {name} VAR -state : {{ {states} }}; + state : {{ {states} }}; +{variables} ASSIGN init(state) := {initial}; next(state) := case {transitions} + TRUE : state; esac; DEFINE {properties} """ -TRAN = " state = {n} : {{{ms}}};" +VARIABLE = " {n}:{t};" + +TRAN = " (state = {n}) & ({cond}) : {m};" PROP = """ -- Property "{prop}" {alias} := {logic}; """ @@ -36,22 +41,36 @@ class ModelBuilder: self._ltls = ltls def generate(self): + cg = c_generator.CGenerator() + cond_tf = NuSmvConditionTransformer() + # build model states_decl = ",".join(self._states) transitions = [] - - for n,ms in self._tran.items(): - transition = TRAN.format(n=n, ms=",".join(ms)) + + # find variables in the condition + varextract = NuSmvVariableExtractor() + + for tran in self._tran: + cond = cond_tf.transform(tran.condition) + expr = cg.visit(cond) + transition = TRAN.format(n=tran.state_from, m=tran.state_to, cond=expr) transitions.append(transition) - + varextract.visit(tran.condition) + + variables = [] + for v,t in varextract.variables.items(): + variables.append(VARIABLE.format(n=v, t=t)) + properties = [] - for alias,(states,prop) in self._props.items(): - logic = " | ".join([PROP_LOGIC.format(state=x) for x in states]) - prop_str = PROP.format(prop=prop, alias=alias, logic=logic) - properties.append(prop_str) + #for alias,(states,prop) in self._props.items(): + # logic = " | ".join([PROP_LOGIC.format(state=x) for x in states]) + # prop_str = PROP.format(prop=prop, alias=alias, logic=logic) + # properties.append(prop_str) out = MODEL.format(name=self._name, states=states_decl, + variables="\n".join(variables), initial=self._initial, transitions="\n".join(transitions), properties="\n".join(properties))