summaryrefslogtreecommitdiff
path: root/gfx_asm/assembler.py
diff options
context:
space:
mode:
authorAlejandro Soto <alejandro@34project.org>2023-11-22 06:45:05 -0600
committerAlejandro Soto <alejandro@34project.org>2023-11-22 06:45:05 -0600
commit5116e5c80541f26d0f87a535539147030ecd2fed (patch)
treedb20d6e990518463d39f5e803eb687766b413ac7 /gfx_asm/assembler.py
parentf81b6c966935296601aee466d33525458d174415 (diff)
gfx_asm: implement
Diffstat (limited to '')
-rw-r--r--gfx_asm/assembler.py352
1 files changed, 352 insertions, 0 deletions
diff --git a/gfx_asm/assembler.py b/gfx_asm/assembler.py
new file mode 100644
index 0000000..51ec648
--- /dev/null
+++ b/gfx_asm/assembler.py
@@ -0,0 +1,352 @@
+import ast, sys, string, struct
+
+REG_STACK = 14
+REG_LINK = 15
+LABEL_CHARSET = string.ascii_letters + '_'
+
+REG_MAP = {
+ 'm0': 0,
+ 'm1': 1,
+ 'm2': 2,
+ 'm3': 3,
+ 'm4': 4,
+ 'm5': 5,
+ 'm6': 6,
+ 'm7': 7,
+}
+
+
+class Ins:
+ def __init__(self, *args, name, line, addr):
+ self.name = name
+ self.line = line
+ self.addr = addr
+ self.args = iter(args)
+
+ def imm_pool(self, pool):
+ pass
+
+ def length(self):
+ return 1
+
+ def stop(self):
+ try:
+ next(self.args)
+ except StopIteration:
+ pass
+ else:
+ self.error(f"Too many arguments")
+
+ def next(self, *, optional=False):
+ try:
+ return next(self.args)
+ except StopIteration:
+ if optional:
+ return None
+
+ self.error(f"Missing arguments")
+
+ def error(self, msg):
+ fail(self.line, f"{self.name}: {msg}")
+
+ def parse_addr(self, *, zero=True):
+ arg = self.next()
+
+ if len(arg) < 2 or arg[0] != "[" or arg[-1] != "]":
+ self.error(f"Invalid syntax: bad addressing mode: {repr(arg)}")
+
+ return self.parse_reg(arg=arg[1:-1], zero=zero)
+
+ def parse_imm(self, *, zero=True):
+ arg, bad = self.next(), False
+
+ try:
+ imm = int(arg, 0)
+ except ValueError:
+ bad = True
+
+ if bad:
+ try:
+ imm = ast.literal_eval(arg)
+ if type(imm) is str:
+ imm = imm.encode('ascii')
+ if len(imm) == 1:
+ imm = imm[0]
+ bad = False
+ except:
+ pass
+
+ if bad:
+ self.error(f"Invalid immediate value: {repr(arg)}")
+ elif not zero and not imm:
+ self.error("Immediate value must not be 0.")
+ elif not (-(1 << 31) <= imm <= (1 << 32) - 1):
+ self.error(f"Immediate exceeds 32 bits: {imm}")
+
+ return imm
+
+ def parse_reg(self, *, zero=True, arg=None, expect=None, optional=False):
+ if not arg:
+ arg = self.next(optional=optional)
+ if arg is None:
+ return None
+
+ arg = arg.lower()
+ if (reg := REG_MAP.get(arg)) is None:
+ self.error(f"Invalid register: {repr(arg)}")
+ elif not zero and not reg:
+ self.error("Register must not be r0")
+ elif expect is not None and reg != expect:
+ self.error(f"Expected register r{expect}, got r{reg}")
+
+ return reg
+
+ def parse_target(self):
+ arg = self.next()
+ if arg == '.':
+ return self.addr
+
+ if not arg or any(c not in LABEL_CHARSET for c in arg):
+ self.error(f"Invalid label: {repr(arg)}")
+
+ return arg
+
+ def encode_reg(self, reg):
+ return self.encode_unsigned(reg, 3)
+
+ def encode_rel(self, labels, label, size, *, offset=1):
+ addr = labels.get(label) if type(label) is str else label
+
+ if addr is None:
+ self.error(f"Undefined reference to {repr(label)}")
+
+ return self.encode_signed(addr - self.addr - offset, size, tag="Jump")
+
+ def encode_signed(self, val, size, *, tag="Value"):
+ lo, hi = -(1 << (size - 1)), (1 << (size - 1)) - 1
+ if not (lo <= val <= hi):
+ self.error(f"{tag} out of range [{lo}, {hi}]: {val}")
+
+ elif val < 0:
+ val += 1 << size
+
+ return self.encode_unsigned(val, size)
+
+ def encode_unsigned(self, val, size):
+ hi = (1 << size) - 1
+ if not (0 <= val <= hi):
+ self.error(f"Value out of range [0, {hi}]: {val}")
+
+ return bin(val)[2:].zfill(size)
+
+
+class Select(Ins):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ self.dst = self.parse_reg()
+ self.src_a = self.parse_reg()
+ self.src_b = self.parse_reg()
+
+ components = {'a': '0', 'b': '1'}
+
+ arg = self.next()
+ self.select = [components.get(v) for v in arg.lower()]
+
+ if len(self.select) != 4 or any(v is None for v in self.select):
+ self.error(f"Bad select mask: {repr(arg)}")
+
+
+ def encode(self, labels):
+ dst = self.encode_reg(self.dst)
+ src_a, src_b = self.encode_reg(self.src_a), self.encode_reg(self.src_b)
+ return ('00000000', ''.join(self.select), '0', src_b, '0', src_a, '0', dst, '00000001')
+
+
+class Swizzle(Ins):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ self.dst = self.parse_reg()
+ self.src = self.parse_reg()
+
+ components = {'x': 3, 'y': 2, 'z': 1, 'w': 0}
+
+ mask_arg = self.next()
+ self.masks = [components.get(v) for v in mask_arg.lower()]
+
+ if len(self.masks) != 4 or any(v is None for v in self.masks):
+ self.error(f"Bad swizzle mask: {repr(mask_arg)}")
+
+ def encode(self, labels):
+ dst = self.encode_reg(self.dst)
+ src = self.encode_reg(self.src)
+ mask = ''.join(self.encode_unsigned(mask, 2) for mask in self.masks)
+ print(mask, file=sys.stderr)
+ return (mask, '000000000', src, '0', dst, '00000010')
+
+
+class Broadcast(Ins):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ self.dst = self.parse_reg()
+
+ imm = self.next()
+ try:
+ self.imm = float(imm)
+ except:
+ self.error(f"Invalid immediate value: {repr(imm)}")
+
+ def encode(self, labels):
+ imm = self.encode_unsigned(int.from_bytes(struct.pack('<e', self.imm), 'little'), 16)
+ dst = self.encode_reg(self.dst)
+ return (imm, '00000', dst, '00000100')
+
+
+class MatVec(Ins):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ self.dst = self.parse_reg()
+ self.src_a = self.parse_reg()
+ self.src_b = self.parse_reg()
+
+ def encode(self, labels):
+ dst = self.encode_reg(self.dst)
+ src_a, src_b = self.encode_reg(self.src_a), self.encode_reg(self.src_b)
+ return ('0000000000000', src_b, '0', src_a, '0', dst, '00001000')
+
+
+class Send(Ins):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ self.src = self.parse_reg()
+
+ def encode(self, labels):
+ src = self.encode_reg(self.src)
+ return ('00000000000000000', src, '000000010000')
+
+
+class Recv(Ins):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ self.dst = self.parse_reg()
+
+ def encode(self, labels):
+ dst = self.encode_reg(self.dst)
+ return ('000000000000000000000', dst, '00100000')
+
+
+ISA = {
+ "select": Select,
+ "swizzl": Swizzle,
+ "broadc": Broadcast,
+ "matvec": MatVec,
+ "send": Send,
+ "recv": Recv,
+}
+
+
+def fail(line, msg):
+ print("At line ", line, ": ", msg, sep="", file=sys.stderr)
+ sys.exit(1)
+
+
+def assemble(file):
+ pc = 0
+ insns = []
+ labels = {}
+ imm_labels = {}
+
+ def get_imm_label(imm):
+ nonlocal imm_labels
+
+ if imm < 0:
+ imm += 1 << 32
+
+ label = imm_labels.get(imm)
+ if not label:
+ label = f'#{hex(imm)[2:].zfill(8)}'
+ imm_labels[imm] = label
+
+ return label
+
+ with open(file, "r") as src:
+ for lineno, line in enumerate(src, start=1):
+ ## comments
+ if (i := line.find("!")) != -1:
+ line = line[:i]
+
+ line = line.strip()
+
+ if len(line) > 1 and line[-1] == ":":
+ label = line[:-1]
+
+ if any(c not in LABEL_CHARSET for c in label):
+ fail(lineno, f"Invalid label: {repr(label)}")
+ elif label in labels:
+ fail(lineno, f"Label already in use: {repr(label)}")
+
+ labels[label] = pc
+
+ continue
+
+ line = line.split(maxsplit=1)
+
+ ## empty lines
+ if not line:
+ continue
+
+ args = (arg.strip() for arg in line[1].split(",")) if len(line) > 1 else ()
+ name = line[0].lower()
+
+ ctor = ISA.get(name)
+
+ if not ctor:
+ fail(lineno, f"Unknown instruction: {repr(name)}")
+
+ insn = ctor(*args, name=name, line=lineno, addr=pc)
+ insn.stop()
+ insn.imm_pool(get_imm_label)
+
+ insns.append(insn)
+ pc += insn.length()
+
+ # Inmediatos tienen que estar alineados a words
+ imm_pool_padding = bool(pc & 1)
+ if imm_pool_padding:
+ pc += 1
+
+ imm_labels = list(imm_labels.items())
+ for imm, label in imm_labels:
+ labels[label] = pc
+ pc += 2
+
+ output = bytearray()
+
+ for insn in insns:
+ encs = insn.encode(labels)
+
+ if type(encs) is not list:
+ encs = [encs]
+
+ assert len(encs) == insn.length()
+
+ for enc in encs:
+ enc = "".join(enc)
+ assert len(enc) == 32 and all(c in ("0", "1") for c in enc)
+
+ output.extend(int(enc, 2).to_bytes(4, "little"))
+
+ return output
+
+
+def main():
+ sys.stdout.buffer.write(assemble(sys.argv[1]))
+
+
+if __name__ == "__main__":
+ main()