"""
CBD to C Code Generator - Template

This template provides the structure for a CBD-to-C transpiler.
Your task is to complete the TODO sections to implement:
1. Dependency graph construction
2. Topological sorting
3. Algebraic loop detection
4. Block code generation
"""

from pyCBD.Core import CBD, BaseBlock
from pyCBD.lib.std import *


class CBD2C:
    """
    Converts a CBD model to standalone C code.

    The model is flattened and every block output port is assigned a row in a
    ``signals[NUM_SIGNALS][HISTORY_SIZE]`` ring buffer. The blocks are then
    emitted into a generated ``cbd_step`` function in dependency order.

    Args:
        cbd: The CBD model to convert (will be flattened)
        max_history: Number of ring-buffer slots per signal (default 2).
            DelayBlock reads its input from the previous slot, so models that
            contain a DelayBlock need at least 2.

    Raises:
        ValueError: If ``max_history`` is smaller than the model requires, or
            two output ports map to the same signal name.
    """

    def __init__(self, cbd, max_history=2):
        if max_history < 1:
            # HISTORY_SIZE 0 would declare a zero-sized array and make the
            # generated ring-buffer arithmetic divide by zero.
            raise ValueError(f"max_history must be at least 1, got {max_history}")
        self.cbd = cbd.flattened()  # Flatten hierarchical models
        if max_history < 2 and any(isinstance(b, DelayBlock) for b in self.cbd.getBlocks()):
            # With a single slot the "previous" slot aliases the current one.
            raise ValueError("max_history must be at least 2 for models containing a DelayBlock")
        self.max_history = max_history

        # Build signal mapping
        self.signal_map = {}
        self._build_signal_map()

        # Build the dependency graph, the evaluation schedule and the loop report
        self.dep_graph = self._build_dependency_graph()
        self.schedule = self._topological_sort(self.dep_graph)
        self.algebraic_loops = self._detect_algebraic_loops(self.dep_graph)

    def _build_signal_map(self):
        """
        Assign one row of the C ``signals`` array to every block output port.

        Keys are ``"<block name>_<port name>"``. Values are consecutive indices
        in ``getBlocks()`` / ``getOutputPorts()`` order. CBD blocks are skipped.

        Raises:
            ValueError: If two ports produce the same key (e.g. block ``"a_b"``
                port ``"c"`` vs block ``"a"`` port ``"b_c"``). Overwriting the
                key would make both blocks write the same signal row.
        """
        idx = 0
        for block in self.cbd.getBlocks():
            if isinstance(block, CBD):
                continue
            for port in block.getOutputPorts():
                signal_name = f"{block.getBlockName()}_{port.name}"
                if signal_name in self.signal_map:
                    raise ValueError(
                        f"Ambiguous signal name {signal_name!r}: rename one of the blocks"
                    )
                self.signal_map[signal_name] = idx
                idx += 1

    def _build_dependency_graph(self):
        """
        Build the dependency graph of the flattened model.

        Returns:
            dict mapping every non-CBD block to the list of blocks whose outputs
            it reads in the same step, i.e. the blocks that must be computed
            before it. Keys follow ``self.cbd.getBlocks()`` order, which keeps
            the resulting schedule deterministic.

        The DelayBlock ``IN1`` edge is deliberately omitted. The generated code
        reads that input from the previous ring-buffer slot, so it never needs
        the current-step value. The ``IC`` edge is kept because iteration 0
        reads ``IC`` within the current step.
        """
        blocks = [b for b in self.cbd.getBlocks() if not isinstance(b, CBD)]
        known = set(blocks)
        graph = {}
        for block in blocks:
            deps = []
            for port in block.getInputPorts():
                if isinstance(block, DelayBlock) and port.name == "IN1":
                    continue
                incoming = port.getIncoming()
                if incoming is None:
                    continue
                source = incoming.source.block
                # Ignore sources outside the flattened block set and duplicate edges.
                if source in known and source not in deps:
                    deps.append(source)
            graph[block] = deps
        return graph

    @staticmethod
    def _strongly_connected_components(dep_graph):
        """
        Compute strongly connected components with an iterative Tarjan's algorithm.

        The algorithm is iterative so that long dependency chains cannot hit
        Python's recursion limit.

        Args:
            dep_graph: dict mapping node -> iterable of nodes it depends on.
                Edges to nodes that are not keys are ignored.

        Returns:
            list of components (lists of nodes). Edges point from a node to its
            dependencies, and Tarjan emits a component only after every
            component it can reach. The returned order is therefore directly a
            valid evaluation order.
        """
        index_of = {}
        lowlink = {}
        on_stack = set()
        stack = []
        components = []
        counter = 0

        for root in dep_graph:
            if root in index_of:
                continue
            index_of[root] = lowlink[root] = counter
            counter += 1
            stack.append(root)
            on_stack.add(root)
            # Each work item is (node, iterator over its remaining dependencies).
            work = [(root, iter(dep_graph[root]))]
            while work:
                node, deps = work[-1]
                descended = False
                for dep in deps:
                    if dep not in dep_graph:
                        continue
                    if dep not in index_of:
                        index_of[dep] = lowlink[dep] = counter
                        counter += 1
                        stack.append(dep)
                        on_stack.add(dep)
                        work.append((dep, iter(dep_graph[dep])))
                        descended = True
                        break
                    if dep in on_stack:
                        lowlink[node] = min(lowlink[node], index_of[dep])
                if descended:
                    continue
                # All dependencies of `node` are processed.
                work.pop()
                if lowlink[node] == index_of[node]:
                    component = []
                    while True:
                        member = stack.pop()
                        on_stack.discard(member)
                        component.append(member)
                        if member == node:
                            break
                    component.reverse()  # discovery order reads more naturally
                    components.append(component)
                if work:
                    parent = work[-1][0]
                    lowlink[parent] = min(lowlink[parent], lowlink[node])
        return components

    def _topological_sort(self, dep_graph):
        """
        Order the blocks so that every block comes after its dependencies.

        Args:
            dep_graph: The dependency graph from _build_dependency_graph

        Returns:
            A list of components in evaluation order. Single-element lists are
            simple blocks; multi-element lists are strongly connected
            components (algebraic loops).
        """
        return self._strongly_connected_components(dep_graph)

    def _detect_algebraic_loops(self, dep_graph):
        """
        Detect algebraic loops in the dependency graph.

        Args:
            dep_graph: The dependency graph

        Returns:
            A list of strongly connected components that form a cycle. These are
            components with more than one block, or a single block that depends
            on itself. DelayBlocks break cycles because their IN1 edge is not in
            the graph.
        """
        return [
            component
            for component in self._strongly_connected_components(dep_graph)
            if len(component) > 1 or component[0] in dep_graph.get(component[0], ())
        ]

    def _get_signal_index(self, block, port_name):
        """
        Get the ``signals`` row index for a block's output port.

        Unknown ports fall back to index 0 for backward compatibility. Callers
        that *write* a signal should check ``self.signal_map`` first.
        """
        signal_name = f"{block.getBlockName()}_{port_name}"
        return self.signal_map.get(signal_name, 0)

    def _input_signal_ref(self, block, input_port_name, slot_expr):
        """Return a C expression reading the input signal at ring slot ``slot_expr``.

        Returns "0.0" when the input port is unconnected.
        """
        port = block.getInputPortByName(input_port_name)
        incoming = port.getIncoming()
        if incoming is None:
            return "0.0"  # Unconnected input
        source_idx = self._get_signal_index(incoming.source.block, incoming.source.name)
        return f"state->signals[{source_idx}][{slot_expr}]"

    def _get_input_signal_ref(self, block, input_port_name):
        """
        Get the C code reference to an input signal in the current step.

        The expression uses ``state->`` because it is emitted inside
        ``cbd_step(CBDState* state)``, where ``state`` is a pointer.
        """
        return self._input_signal_ref(block, input_port_name, "state->current_idx")

    def _get_previous_input_signal_ref(self, block, input_port_name):
        """Get the C reference to an input signal's value from the previous step."""
        return self._input_signal_ref(
            block, input_port_name, "(state->current_idx + HISTORY_SIZE - 1) % HISTORY_SIZE"
        )

    @staticmethod
    def _c_double_literal(value):
        """Render a number as a C double literal; NaN/inf use the <math.h> macros."""
        import math

        number = float(value)
        if math.isnan(number):
            return "NAN"
        if math.isinf(number):
            return "INFINITY" if number > 0 else "-INFINITY"
        # repr() is the shortest string that round-trips to the same double.
        return repr(number)

    @staticmethod
    def _c_string_literal(text):
        """Render ``text`` as a quoted C string literal (escapes \\, " and newlines)."""
        escaped = text.replace("\\", "\\\\").replace('"', '\\"').replace("\n", "\\n")
        return f'"{escaped}"'

    @staticmethod
    def _c_comment_text(text):
        """Make ``text`` safe inside ``/* ... */`` (a literal ``*/`` would end the comment)."""
        return str(text).replace("*/", "* /")

    def generate(self, output_file=None):
        """
        Generate C code for the CBD model.

        Args:
            output_file: Path to write C file to (None = return as string)

        Returns:
            Generated C code as string
        """
        code_parts = [
            self._generate_header(),
            self._generate_state_struct(),
            self._generate_init_function(),
            self._generate_step_function(),  # main computation
            self._generate_main_function(),
        ]
        code = '\n'.join(code_parts)

        if output_file:
            # Block names may be non-ASCII; do not depend on the locale encoding.
            with open(output_file, 'w', encoding="utf-8") as f:
                f.write(code)

        return code

    def _generate_header(self):
        """Generate C header with includes and comments."""
        model_name = self._c_comment_text(self.cbd.getBlockName())
        # A zero-length array is not valid ISO C, so always reserve one row.
        num_signals = max(len(self.signal_map), 1)
        return f"""/*
 * Generated C code for CBD model: {model_name}
 * 
 * This code was automatically generated from a pyCBD model.
 * 
 * Compilation: gcc -o cbd_sim cbd_sim.c -lm
 * Usage: ./cbd_sim [num_iterations] [delta_t] [trace_file]
 */

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>

/* Configuration */
#define MAX_ITERATIONS 1000
#define DEFAULT_DELTA_T 1.0
#define HISTORY_SIZE {self.max_history}
#define NUM_SIGNALS {num_signals}
"""

    def _generate_state_struct(self):
        """Generate the state struct definition."""
        return """
/* CBD State Structure */
typedef struct {
    double signals[NUM_SIGNALS][HISTORY_SIZE];  /* Signal values with history */
    int current_idx;      /* Ring buffer index for current values */
    double time;          /* Current simulation time */
    double delta_t;       /* Time step size */
    int iteration;        /* Current iteration number */
} CBDState;
"""

    def _generate_init_function(self):
        """Generate state initialization function."""
        return """
/* Initialize CBD state */
void cbd_init(CBDState* state, double delta_t) {
    state->current_idx = 0;
    state->time = 0.0;
    state->delta_t = delta_t;
    state->iteration = 0;
    
    /* Initialize all signals to 0.0 */
    for (int i = 0; i < NUM_SIGNALS; i++) {
        for (int j = 0; j < HISTORY_SIZE; j++) {
            state->signals[i][j] = 0.0;
        }
    }
}
"""

    def _generate_step_function(self):
        """
        Generate the main computation step function.

        Blocks are emitted in ``self.schedule`` order. Members of an algebraic
        loop are emitted once each, in schedule order, under a warning comment.
        They are not solved simultaneously.
        """
        lines = ["""
/* Execute one simulation step */
void cbd_step(CBDState* state) {
    /* Compute blocks in dependency order */
"""]

        for component in self.schedule:
            if len(component) > 1 or component in self.algebraic_loops:
                names = ", ".join(self._c_comment_text(b.getBlockName()) for b in component)
                lines.append(
                    f"    /* WARNING: Algebraic loop detected ({names}): blocks are "
                    "evaluated once in schedule order, not solved simultaneously */"
                )
            for block in component:
                if isinstance(block, BaseBlock) and not isinstance(block, CBD):
                    lines.append(
                        f"    /* Block: {self._c_comment_text(block.getBlockName())} "
                        f"({block.getBlockType()}) */"
                    )
                    lines.append(self._generate_block_computation(block))

        lines.append("""
    /* Update iteration */
    state->iteration++;
}
""")
        return '\n'.join(lines)

    def _generate_block_computation(self, block):
        """
        Generate C code for a specific block's computation.

        The code is emitted inside ``cbd_step(CBDState* state)``, so all state
        accesses use ``state->``.

        TODO: Complete this function for all block types you use.
        """
        block_type = block.getBlockType()
        block_name = block.getBlockName()
        if f"{block_name}_OUT1" not in self.signal_map:
            # Without this guard, _get_signal_index's fallback of 0 would
            # overwrite signal 0.
            return f"    /* {self._c_comment_text(block_type)}: no OUT1 signal, nothing to compute */\n"
        output_idx = self._get_signal_index(block, "OUT1")
        out_ref = f"state->signals[{output_idx}][state->current_idx]"

        # ConstantBlock
        if isinstance(block, ConstantBlock):
            return f"    {out_ref} = {self._c_double_literal(block.getValue())};\n"

        # AdderBlock
        if isinstance(block, AdderBlock):
            num_inputs = block.getNumberOfInputs()
            input_refs = [self._get_input_signal_ref(block, f"IN{i}") for i in range(1, num_inputs + 1)]
            expr = " + ".join(input_refs) if input_refs else "0.0"  # empty sum is 0
            return f"    {out_ref} = {expr};\n"

        # DelayBlock: IC at iteration 0, otherwise IN1's value from the previous step
        if isinstance(block, DelayBlock):
            ic_ref = self._get_input_signal_ref(block, "IC")
            prev_ref = self._get_previous_input_signal_ref(block, "IN1")
            return (
                f"    if (state->iteration == 0) {{\n"
                f"        {out_ref} = {ic_ref};\n"
                f"    }} else {{\n"
                f"        {out_ref} = {prev_ref};\n"
                f"    }}\n"
            )

        # TODO: Add more block types here
        # - ProductBlock
        # - NegatorBlock
        # - InverterBlock
        # - GenericBlock (for sin, cos, etc.)
        # - LessThanBlock, EqualsBlock
        # - AbsBlock, IntBlock
        # - etc.

        return f"    /* TODO: Implement {block_type} */\n    {out_ref} = 0.0;\n"

    def _generate_main_function(self):
        """
        Generate the main function with the simulation loop and CSV tracing.

        Each trace row holds the loop counter, the current time and every signal
        in ``signal_map`` index order. The completion message goes to stderr so
        the CSV stays clean when it is written to stdout.
        """
        signal_names = sorted(self.signal_map, key=self.signal_map.get)
        # Emitted with fputs, so '%' in a block name is not parsed as a format.
        header_literal = self._c_string_literal("iteration,time," + ",".join(signal_names) + "\n")
        trace_code = '\n        '.join(
            f'fprintf(trace_file, ",%f", state.signals[{self.signal_map[name]}][state.current_idx]);'
            for name in signal_names
        )

        return f"""
/* Main simulation function */
int main(int argc, char* argv[]) {{
    /* Parse command line arguments */
    int num_iterations = MAX_ITERATIONS;
    double delta_t = DEFAULT_DELTA_T;
    FILE* trace_file = NULL;
    
    if (argc > 1) num_iterations = atoi(argv[1]);
    if (argc > 2) delta_t = atof(argv[2]);
    if (num_iterations < 0) num_iterations = 0;
    if (argc > 3) {{
        trace_file = fopen(argv[3], "w");
        if (!trace_file) {{
            fprintf(stderr, "Error: Cannot open trace file '%s'\\n", argv[3]);
            return 1;
        }}
    }} else {{
        trace_file = stdout;
    }}
    
    /* Initialize CBD state */
    CBDState state;
    cbd_init(&state, delta_t);
    
    /* Write CSV header */
    fputs({header_literal}, trace_file);
    
    /* Simulation loop */
    for (int iter = 0; iter < num_iterations; iter++) {{
        /* Execute one step */
        cbd_step(&state);
        
        /* Write trace data; use the loop counter so the row's iteration matches
           state.time (cbd_step has already incremented state.iteration) */
        fprintf(trace_file, "%d,%f", iter, state.time);
        {trace_code}
        fprintf(trace_file, "\\n");
        
        /* Advance time and ring buffer */
        state.time += state.delta_t;
        state.current_idx = (state.current_idx + 1) % HISTORY_SIZE;
    }}
    
    /* Cleanup */
    if (trace_file != stdout) {{
        fclose(trace_file);
    }}
    
    /* Report on stderr so a CSV trace on stdout is not corrupted */
    fprintf(stderr, "Simulation complete: %d iterations\\n", num_iterations);
    return 0;
}}
"""


# Convenience function
def generate_c_code(cbd, output_file=None, max_history=2):
    """
    Transpile a CBD model to C in a single call.

    Thin wrapper: constructs a ``CBD2C`` for ``cbd`` and immediately calls
    its ``generate`` method.

    Args:
        cbd: The CBD model to convert
        output_file: Path to write the C source to (None = only return it)
        max_history: Ring-buffer depth per signal, forwarded to ``CBD2C``

    Returns:
        Generated C code as string

    Example:
        >>> from pyCBD.Core import CBD
        >>> from pyCBD.lib.std import *
        >>> 
        >>> cbd = CBD("example")
        >>> cbd.addBlock(ConstantBlock("c1", 5.0))
        >>> cbd.addBlock(ConstantBlock("c2", 3.0))
        >>> cbd.addBlock(AdderBlock("add"))
        >>> cbd.addConnection("c1", "add", input_port_name="IN1")
        >>> cbd.addConnection("c2", "add", input_port_name="IN2")
        >>> 
        >>> code = generate_c_code(cbd, "example.c")
        >>> # Compile: gcc -o example example.c -lm
        >>> # Run: ./example 100 0.1 trace.csv
    """
    transpiler = CBD2C(cbd, max_history=max_history)
    c_source = transpiler.generate(output_file=output_file)
    return c_source


if __name__ == "__main__":
    # Print the assignment banner, one message per line.
    _banner_messages = (
        "CBD2C Template - Complete the TODO sections to implement the transpiler!",
        "\nTODO List:",
        "1. Implement _build_dependency_graph()",
        "2. Implement _topological_sort()",
        "3. Implement _detect_algebraic_loops()",
        "4. Complete _generate_block_computation() for all block types",
        "\nGood luck!",
    )
    for _message in _banner_messages:
        print(_message)

