import os.path

from api.od import ODAPI
from bootstrap.scd import bootstrap_scd
from concrete_syntax.textual_od.parser import parse_od
from state.devstate import DevState
from transformation.matcher import match_od
from transformation.ramify import ramify
from transformation.rule import RuleMatcherRewriter, PriorityActionGenerator
from util.loader import parse_and_check, load_rules
from models import factory_mm, factory_rt_mm, small_factory_m, small_factory_rt_m
from simulator import FactorySimulator


class TerminationCondition:
    """Decide whether the factory simulation should stop.

    On construction, a fixed set of named stop patterns (textual OD syntax)
    is parsed against the RAMified runtime meta-model. Calling the instance
    with an object diagram returns the name ("cause") of the first pattern
    that matches it, or ``None`` when none of them match.
    """

    def __init__(self, state, rt_mm):
        """Prepare the stop patterns.

        :param state: model state handed to ``ramify``, ``parse_od`` and,
            later, ``match_od``.
        :param rt_mm: runtime meta-model; it is RAMified once here and the
            patterns are parsed against the RAMified version.
        """
        self.state = state
        self.rt_mm_ramified = ramify(state, rt_mm)

        # Human-readable cause -> pattern in textual OD concrete syntax.
        # Insertion order is the order in which __call__ tries them.
        stop_patterns_cs = {
            "There are at least two items accepted": """
                anAccept:RAM_StoreAccept {
                    condition = `len(get_incoming(this, "stored")) >= 2`;
                }
                """,
            "There is a rejected item": """
                aReject:RAM_StoreReject {
                    condition = `len(get_incoming(this, "stored")) >= 1`;
                }
                """
        }

        self.patterns = {}
        for cause, pattern_cs in stop_patterns_cs.items():
            self.patterns[cause] = parse_od(state, pattern_cs, self.rt_mm_ramified)

    def __call__(self, od):
        """Return the cause of the first stop pattern matching ``od``, else ``None``.

        :param od: object-diagram API; its ``m`` (model) and ``mm``
            (meta-model) attributes are passed to ``match_od``.
        """
        for cause, pattern in self.patterns.items():
            candidate_matches = match_od(self.state, od.m, od.mm, pattern, self.rt_mm_ramified)
            # A single match is enough: any() stops after consuming at most
            # one element of the match iterable.
            if any(True for _ in candidate_matches):
                return cause
        return None


def render_text(od: ODAPI):
    """Render the runtime factory model as text, one line per work station.

    For every ``WorkStation`` instance, the line contains, in order:

    * its input queues (sources of incoming ``queueIn`` links), if any;
    * the station name;
    * a state symbol, if the station has an outgoing ``stateOf`` link. The
      ``'processing'`` slot of that link's target selects ☑️ for
      ``"processing"``, ✅ for ``"processed"`` and 🆓 otherwise;
    * the name of the worker (source of the first incoming ``workingOn``
      link), if any;
    * its output queues (sources of incoming ``queueOut`` links), if any.

    A queue is drawn as ``name (⚙️…🔲…)``: one ⚙️ per incoming ``inQueue``
    link, then one 🔲 per remaining unit of its ``capacity`` slot.

    :param od: object-diagram API over the runtime model to render.
    :return: the rendered text; each station's line ends with a newline.
    """

    def queue_gauge(queue_obj):
        # Fill gauge for one queue; a negative remainder simply renders no 🔲.
        used = len(od.get_incoming(queue_obj, "inQueue"))
        free = od.get_slot_value(queue_obj, "capacity") - used
        return f"{od.get_name(queue_obj)} ({'⚙️' * used}{'🔲' * free})"

    lines = []
    for _, station in od.get_all_instances("WorkStation"):
        parts = []

        inputs = od.get_incoming(station, "queueIn")
        if inputs:
            gauges = "".join(f"{queue_gauge(od.get_source(link))} " for link in inputs)
            parts.append(f"🏭 {gauges}🏭 --> ")

        parts.append(f"{od.get_name(station)}")

        state_links = od.get_outgoing(station, "stateOf")
        if state_links:
            station_state = od.get_target(state_links[0])
            # NOTE(review): the slot read here is named 'processing' and its
            # value is compared against "processing"/"processed"; confirm
            # this slot name against the runtime meta-model.
            status = od.get_slot_value(station_state, 'processing')
            if status == "processing":
                symbol = "☑️"
            elif status == "processed":
                symbol = "✅"
            else:
                symbol = "🆓"
            parts.append(f" [{symbol}]")

        workers = od.get_incoming(station, "workingOn")
        if workers:
            parts.append(f" <🧑🏼‍🏭 {od.get_name(od.get_source(workers[0]))}>")

        outputs = od.get_incoming(station, "queueOut")
        if outputs:
            gauges = "".join(f" {queue_gauge(od.get_source(link))}" for link in outputs)
            parts.append(f" --> 🏭{gauges} 🏭")

        lines.append("".join(parts) + "\n")

    return "".join(lines)


def get_filename(name, kind):
    """Return the path of a rule file stored next to this module.

    Rule files live in a ``rules`` sub-directory beside this source file and
    are named ``r_<name>_<kind>.od``. This function is handed to
    ``load_rules``, which decides which ``kind`` values it asks for.

    :param name: rule name, e.g. ``"store"``.
    :param kind: the ``<kind>`` part of the file name, as requested by the caller.
    :return: the path to the ``.od`` file; its existence is not checked.
    """
    # os.path.join instead of f"{dir}/rules/...": when dirname(__file__) is ""
    # (e.g. a script run from its own directory on Python < 3.9) the f-string
    # form yields the absolute "/rules/..." at the filesystem root, whereas
    # join keeps the path relative to the working directory.
    return os.path.join(os.path.dirname(__file__), "rules", f"r_{name}_{kind}.od")


def get_rules(current_state, rt_mm):
    """Build the rule-based action generator for the factory simulation.

    The runtime meta-model is RAMified once. A ``RuleMatcherRewriter`` is
    created over it, and the rule files (located through ``get_filename``)
    are loaded in five groups. The groups are handed to
    ``PriorityActionGenerator`` in the order listed below.

    :param current_state: model state shared with the rest of the simulation.
    :param rt_mm: runtime meta-model the rules operate on.
    :return: a ``PriorityActionGenerator`` over the loaded rule groups.
    """
    print("Loading rules")
    rt_mm_ramified = ramify(current_state, rt_mm)
    matcher_rewriter = RuleMatcherRewriter(current_state, rt_mm, rt_mm_ramified)

    # One entry per priority group, in the order given to the generator.
    # Presumably earlier groups take precedence over later ones; confirm in
    # PriorityActionGenerator.
    priority_groups = (
        ['store'],
        ['accept', 'reject', 'finish_assembly'],
        ['process'],
        ['start_assembly', 'start_inspection'],
        ['add_A', 'add_B'],
    )
    rule_groups = [
        load_rules(current_state, get_filename, rt_mm_ramified, rule_names)
        for rule_names in priority_groups
    ]

    return PriorityActionGenerator(matcher_rewriter, rule_groups)


# --- Simulation entry point --------------------------------------------------
# NOTE(review): this runs at import time because there is no
# `if __name__ == "__main__":` guard, so importing this module starts a
# full simulation run.
state = DevState()
# The SCD meta-model bootstrapped here is used below as the meta-model for
# both the static and the runtime factory meta-models.
scd_mm = bootstrap_scd(state)

# Static Models: the factory meta-model (parsed against SCD) and the small
# factory model (parsed against that meta-model). parse_and_check presumably
# parses the textual model and checks conformance; the last argument appears
# to be a label for diagnostics. fact_m is not used further in this script.
fact_mm = parse_and_check(state, factory_mm, scd_mm, "MM for Factory")
fact_m = parse_and_check(state, small_factory_m, fact_mm, "Model for Factory")
# Runtime Models: same structure, for the state that evolves during simulation.
rt_fact_mm = parse_and_check(state, factory_rt_mm, scd_mm, "Runtime MM for Factory")
rt_fact_m = parse_and_check(state, small_factory_rt_m, rt_fact_mm, "Runtime Model for Factory")

print("About to start the simulation...")
# Both the rules and the termination patterns are built over the *runtime*
# meta-model (each RAMifies it internally).
rule_generator = get_rules(state, rt_fact_mm)
rule_sim = FactorySimulator(
    action_generator=rule_generator,
    termination_condition=TerminationCondition(state, rt_fact_mm),
    renderer=render_text
)

# The simulator runs on an ODAPI view of the runtime model.
factory_od = ODAPI(state, rt_fact_m, rt_fact_mm)
rule_sim.run(factory_od)
