Files
Adat/gateware/channel_stream_combiner.py
2026-03-13 19:58:28 +03:00

154 lines
7.5 KiB
Python

from amaranth import *
from amaranth.build import Platform
from amlib.stream import StreamInterface
from amlib.test import GatewareTestCase, sync_test_case
class ChannelStreamCombiner(Elaboratable):
SAMPLE_WIDTH = 24
def __init__(self, no_lower_channels, no_upper_channels):
self.no_lower_channels = no_lower_channels
self.no_upper_channels = no_upper_channels
self.lower_channel_bits = Shape.cast(range(no_lower_channels)).width
self.upper_channel_bits = Shape.cast(range(no_upper_channels)).width
self.combined_channel_bits = Shape.cast(range(no_lower_channels + no_upper_channels)).width
self.lower_channel_stream_in = StreamInterface(name="lower_channels",
payload_width=self.SAMPLE_WIDTH,
extra_fields=[("channel_nr", self.lower_channel_bits)])
self.upper_channels_active_in = Signal()
self.upper_channel_stream_in = StreamInterface(name="upper_channels",
payload_width=self.SAMPLE_WIDTH,
extra_fields=[("channel_nr", self.lower_channel_bits)])
self.combined_channel_stream_out = StreamInterface(name="combined_channels",
payload_width=self.SAMPLE_WIDTH,
extra_fields=[("channel_nr", self.combined_channel_bits)])
# debug signals
self.state = Signal(range(3))
self.upper_channel_counter = Signal(self.upper_channel_bits)
# debug signals
self.state = Signal(range(3))
def elaborate(self, platform: Platform) -> Module:
m = Module()
upper_channel_counter = Signal(self.upper_channel_bits)
m.d.comb += self.upper_channel_counter.eq(upper_channel_counter)
with m.FSM() as fsm:
m.d.comb += self.state.eq(fsm.state)
with m.State("LOWER_CHANNELS"):
with m.If( self.combined_channel_stream_out.ready
& self.lower_channel_stream_in.valid):
m.d.comb += [
self.combined_channel_stream_out.payload.eq(self.lower_channel_stream_in.payload),
self.combined_channel_stream_out.channel_nr.eq(self.lower_channel_stream_in.channel_nr),
self.lower_channel_stream_in.ready.eq(1),
self.combined_channel_stream_out.valid.eq(1),
self.combined_channel_stream_out.first.eq(self.lower_channel_stream_in.first),
self.combined_channel_stream_out.last.eq(0),
]
with m.If(self.lower_channel_stream_in.last):
m.d.sync += upper_channel_counter.eq(0)
with m.If(self.upper_channels_active_in):
m.next = "UPPER_CHANNELS"
with m.Else():
m.next = "FILL_UPPER"
with m.State("UPPER_CHANNELS"):
with m.If(self.combined_channel_stream_out.ready):
with m.If(self.upper_channels_active_in):
with m.If(self.upper_channel_stream_in.valid):
m.d.comb += [
self.combined_channel_stream_out.payload.eq(self.upper_channel_stream_in.payload),
self.combined_channel_stream_out.channel_nr.eq(\
self.no_lower_channels +
self.upper_channel_stream_in.channel_nr),
self.upper_channel_stream_in.ready.eq(1),
self.combined_channel_stream_out.valid.eq(1),
self.combined_channel_stream_out.first.eq(0),
self.combined_channel_stream_out.last.eq(self.upper_channel_stream_in.last),
]
with m.If(self.upper_channel_stream_in.last):
with m.If(self.upper_channel_stream_in.channel_nr == 1):
m.d.sync += upper_channel_counter.eq(2)
m.next = "FILL_UPPER"
with m.Else():
m.next = "LOWER_CHANNELS"
# USB is not active: Fill in zeros
with m.Else():
m.next = "FILL_UPPER"
with m.State("FILL_UPPER"):
with m.If(self.combined_channel_stream_out.ready):
with m.If( ~self.upper_channels_active_in
& self.upper_channel_stream_in.valid):
# we just drain all stale data from the upstream FIFOs
# if the upper channels are not active
m.d.comb += self.upper_channel_stream_in.ready.eq(1)
m.d.sync += upper_channel_counter.eq(upper_channel_counter + 1)
m.d.comb += [
self.combined_channel_stream_out.payload.eq(0),
self.combined_channel_stream_out.channel_nr.eq(\
self.no_lower_channels +
upper_channel_counter),
self.combined_channel_stream_out.valid.eq(1),
self.combined_channel_stream_out.first.eq(0),
self.combined_channel_stream_out.last.eq(0),
]
last_channel = upper_channel_counter >= (self.no_upper_channels - 1)
with m.If(last_channel):
m.d.comb += self.combined_channel_stream_out.last.eq(1)
m.next = "LOWER_CHANNELS"
return m
class ChannelStreamCombinerTest(GatewareTestCase):
FRAGMENT_UNDER_TEST = ChannelStreamCombiner
FRAGMENT_ARGUMENTS = dict(no_lower_channels=32, no_upper_channels=6)
def send_lower_frame(self, sample: int, channel: int, wait=False):
yield self.dut.lower_channel_stream_in.channel_nr.eq(channel)
yield self.dut.lower_channel_stream_in.payload.eq(sample)
yield self.dut.lower_channel_stream_in.valid.eq(1)
yield self.dut.lower_channel_stream_in.first.eq(channel == 0)
yield self.dut.lower_channel_stream_in.last.eq(channel == 31)
yield
yield self.dut.lower_channel_stream_in.valid.eq(0)
if wait:
yield
def send_upper_frame(self, sample: int, channel: int, last_channel: int=6, wait=False):
yield self.dut.upper_channel_stream_in.channel_nr.eq(channel)
yield self.dut.upper_channel_stream_in.payload.eq(sample)
yield self.dut.upper_channel_stream_in.valid.eq(1)
yield self.dut.upper_channel_stream_in.first.eq(channel == 0)
yield self.dut.upper_channel_stream_in.last.eq(channel == last_channel)
yield
yield self.dut.upper_channel_stream_in.valid.eq(0)
if wait:
yield
@sync_test_case
def test_smoke(self):
dut = self.dut
yield
yield dut.combined_channel_stream_out.ready.eq(1)
for channel in range(32):
yield from self.send_lower_frame(channel, channel)
yield from self.advance_cycles(13)
for channel in range(32):
yield from self.send_lower_frame(channel, channel)
yield from self.advance_cycles(13)