#!/usr/bin/env python3
"""Diff the MIPS assembly of a function between a reference build and ours.

Project-specific paths and make flags come from ``diff_settings.py``, which is
looked up in the current working directory before the usual import path.
"""
import sys
import re
import os
import ast
import argparse
import subprocess
import difflib
import string
import itertools
import threading
import queue
import time


def fail(msg):
    """Print *msg* to stderr and terminate the process with exit status 1."""
    print(msg, file=sys.stderr)
    sys.exit(1)


try:
    import attr
    from colorama import Fore, Style, Back
    import ansiwrap
    import watchdog
except ModuleNotFoundError as e:
    fail(f"Missing prerequisite python module {e.name}. "
         "Run `python3 -m pip install --user colorama ansiwrap attrs watchdog` to install prerequisites.")

# Prefer to use diff_settings.py from the current working directory
sys.path.insert(0, '.')
try:
    import diff_settings
except ModuleNotFoundError as e:
    if e.name != 'diff_settings':
        # diff_settings.py itself was found, but something *it* imports is
        # missing; let that error surface instead of blaming the settings file.
        raise
    fail("Unable to find diff_settings.py in the same directory.")

# ==== CONFIG ====

parser = argparse.ArgumentParser(
    description="Diff MIPS assembly.")
parser.add_argument(
    'start',
    help="Function name or address to start diffing from.")
parser.add_argument(
    'end', nargs='?',
    help="Address to end diff at.")
parser.add_argument(
    '-o', dest='diff_obj', action='store_true',
    help="Diff .o files rather than a whole binary. This makes it possible to see symbol names. (Recommended)")
parser.add_argument(
    '--base-asm', dest='base_asm', metavar='FILE',
    help="Read assembly from given file instead of configured base img.")
parser.add_argument(
    '--write-asm', dest='write_asm', metavar='FILE',
    help="Write the current assembly output to file, e.g. for use with --base-asm.")
parser.add_argument(
    '-m', '--make', dest='make', action='store_true',
    help="Automatically run 'make' on the .o file or binary before diffing.")
parser.add_argument(
    '-l', '--skip-lines', dest='skip_lines', type=int, default=0,
    help="Skip the first N lines of output.")
parser.add_argument(
    '-f', '--stop-jr-ra', dest='stop_jrra', action='store_true',
    help="Stop disassembling at the first 'jr ra'. Some functions have multiple return points, so use with care!")
parser.add_argument(
    '-i', '--ignore-large-imms', dest='ignore_large_imms', action='store_true',
    help="Pretend all large enough immediates are the same.")
parser.add_argument(
    '-B', '--no-show-branches', dest='show_branches', action='store_false',
    help="Don't visualize branches/branch targets.")
parser.add_argument(
    '-S', '--base-shift', dest='base_shift', type=str, default='0',
    help="Diff position X in our img against position X + shift in the base img. "
         "Arithmetic is allowed, so e.g. |-S \"0x1234 - 0x4321\"| is a reasonable "
         "flag to pass if it is known that position 0x1234 in the base img syncs "
         "up with position 0x4321 in our img. Not supported together with -o.")
parser.add_argument(
    '-w', '--watch', dest='watch', action='store_true',
    help="Automatically update when source/object files change. "
         "Recommended in combination with -m.")
parser.add_argument(
    '--width', dest='column_width', type=int, default=50,
    help="Sets the width of the left and right view column.")

# Project-specific flags, e.g. different versions/make arguments.
if hasattr(diff_settings, "add_custom_arguments"):
    diff_settings.add_custom_arguments(parser)

args = parser.parse_args()

# Set imgs, map file and make flags in a project-specific manner.
config = {}
diff_settings.apply(config, args)

baseimg = config.get('baseimg', None)
myimg = config.get('myimg', None)
mapfile = config.get('mapfile', None)
makeflags = config.get('makeflags', [])
source_directories = config.get('source_directories', None)

MAX_FUNCTION_SIZE_LINES = 1024
MAX_FUNCTION_SIZE_BYTES = 1024 * 4

COLOR_ROTATION = [
    Fore.MAGENTA,
    Fore.CYAN,
    Fore.GREEN,
    Fore.RED,
    Fore.LIGHTYELLOW_EX,
    Fore.LIGHTMAGENTA_EX,
    Fore.LIGHTCYAN_EX,
    Fore.LIGHTGREEN_EX,
    Fore.LIGHTBLACK_EX,
]

# 'tail' buffers all of its input before writing, see Display.run_less.
BUFFER_CMD = ["tail", "-c", str(10**9)]
LESS_CMD = ["less", "-Ric"]

DEBOUNCE_DELAY = 0.1
FS_WATCH_EXTENSIONS = ['.c', '.h']

# ==== LOGIC ====

# Use the first MIPS binutils prefix whose objdump runs successfully.
binutils_prefix = None

for binutils_cand in ['mips-linux-gnu-', 'mips64-elf-']:
    try:
        subprocess.check_call([binutils_cand + "objdump", "--version"],
                              stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        binutils_prefix = binutils_cand
        break
    except (subprocess.CalledProcessError, FileNotFoundError):
        pass

if not binutils_prefix:
    fail("Missing binutils; please ensure mips-linux-gnu-objdump or mips64-elf-objdump exist.")


def eval_int(expr, emsg=None):
    """Evaluate *expr* as an integer expression.

    Accepts Python literals (e.g. ``0x1234``) combined with unary and binary
    ``+``/``-``, e.g. ``"0x1234 - 0x4321"``. ``ast.literal_eval`` alone only
    allows ``+``/``-`` for complex literals on Python 3.7+, so the arithmetic
    is walked here explicitly. Nothing else (names, calls, ...) is evaluated.

    Returns the integer, or None if *expr* is not a valid integer expression.
    When *emsg* is given, an invalid expression calls ``fail(emsg)`` instead.
    """
    def _eval(node):
        if isinstance(node, ast.BinOp) and isinstance(node.op, (ast.Add, ast.Sub)):
            lhs = _eval(node.left)
            rhs = _eval(node.right)
            return lhs + rhs if isinstance(node.op, ast.Add) else lhs - rhs
        if isinstance(node, ast.UnaryOp) and isinstance(node.op, (ast.UAdd, ast.USub)):
            val = _eval(node.operand)
            return val if isinstance(node.op, ast.UAdd) else -val
        val = ast.literal_eval(node)
        if not isinstance(val, int):
            raise ValueError("not an integer")
        return val

    try:
        return _eval(ast.parse(expr.strip(), mode='eval').body)
    except Exception:
        if emsg is not None:
            fail(emsg)
        return None


def run_make(target, capture_output=False):
    """Run ``make`` with the configured flags for *target*.

    With *capture_output*, return the CompletedProcess with stdout/stderr
    captured (failures are left to the caller). Otherwise inherit the terminal
    and raise CalledProcessError on failure.
    """
    if capture_output:
        return subprocess.run(["make"] + makeflags + [target],
                              stderr=subprocess.PIPE, stdout=subprocess.PIPE)
    else:
        subprocess.check_call(["make"] + makeflags + [target])


def restrict_to_function(dump, fn_name):
    """Return at most MAX_FUNCTION_SIZE_LINES lines of *dump* following the
    first ``<fn_name>:`` header line (the header itself is excluded)."""
    out = []
    search = f'<{fn_name}>:'
    found = False
    for line in dump.split('\n'):
        if found:
            if len(out) >= MAX_FUNCTION_SIZE_LINES:
                break
            out.append(line)
        elif search in line:
            found = True
    return '\n'.join(out)


def run_objdump(cmd):
    """Run objdump for a ``(flags, target, restrict)`` tuple and return its text.

    If *restrict* is not None, the output is cut down with restrict_to_function.
    """
    flags, target, restrict = cmd
    out = subprocess.check_output([binutils_prefix + "objdump"] + flags + [target],
                                  universal_newlines=True)
    if restrict is not None:
        return restrict_to_function(out, restrict)
    return out


base_shift = eval_int(args.base_shift, "Failed to parse --base-shift (-S) argument as an integer.")


def search_map_file(fn_name):
    """Look up *fn_name* in the configured linker map file.

    Returns ``(objfile, rom_address)`` for the single match, or ``(None, None)``
    if there is no match. Exits via fail() when no map file is configured, the
    file cannot be read, parsing fails, or the name occurs more than once.
    """
    if not mapfile:
        fail(f"No map file configured; cannot find function {fn_name}.")

    try:
        with open(mapfile) as f:
            lines = f.read().split('\n')
    except Exception:
        fail(f"Failed to open map file {mapfile} for reading.")

    try:
        cur_objfile = None
        ram_to_rom = None
        cands = []
        last_line = ''
        for line in lines:
            if line.startswith(' .text'):
                cur_objfile = line.split()[3]
            if 'load address' in line:
                # The RAM address is on the previous line, the ROM (load)
                # address on this one.
                tokens = last_line.split() + line.split()
                ram = int(tokens[1], 0)
                rom = int(tokens[5], 0)
                ram_to_rom = rom - ram
            if line.endswith(' ' + fn_name):
                ram = int(line.split()[0], 0)
                if cur_objfile is not None and ram_to_rom is not None:
                    cands.append((cur_objfile, ram + ram_to_rom))
            last_line = line
    except Exception:
        import traceback
        traceback.print_exc()
        fail("Internal error while parsing map file")

    if len(cands) > 1:
        fail(f"Found multiple occurrences of function {fn_name} in map file.")
    if len(cands) == 1:
        return cands[0]
    return None, None


def dump_objfile():
    """Prepare objdump commands for diffing a function's .o file (``-o`` mode).

    Returns ``(make_target, base_cmd, my_cmd)``; the reference object is
    expected at ``expected/<objfile>``.
    """
    if base_shift:
        fail("--base-shift not compatible with -o")
    if args.end is not None:
        fail("end address not supported together with -o")
    if args.start.startswith('0'):
        fail("numerical start address not supported with -o; pass a function name")

    objfile, _ = search_map_file(args.start)
    if not objfile:
        fail("Not able to find .o file for function.")

    if args.make:
        run_make(objfile)

    if not os.path.isfile(objfile):
        fail("Not able to find .o file for function.")

    refobjfile = "expected/" + objfile
    if not os.path.isfile(refobjfile):
        fail(f'Please ensure an OK .o file exists at "{refobjfile}".')

    objdump_flags = ["-drz"]
    return (
        objfile,
        (objdump_flags, refobjfile, args.start),
        (objdump_flags, objfile, args.start),
    )


def dump_binary():
    """Prepare objdump commands for diffing an address range of the images.

    The start may be an integer expression or a function name (resolved via
    the map file). The base image range is offset by --base-shift. Returns
    ``(make_target, base_cmd, my_cmd)``.
    """
    if not baseimg or not myimg:
        fail("Missing myimg/baseimg in config.")
    if args.make:
        run_make(myimg)
    start_addr = eval_int(args.start)
    if start_addr is None:
        _, start_addr = search_map_file(args.start)
        if start_addr is None:
            fail("Not able to find function in map file.")
    if args.end is not None:
        end_addr = eval_int(args.end, "End address must be an integer expression.")
    else:
        end_addr = start_addr + MAX_FUNCTION_SIZE_BYTES
    objdump_flags = ['-Dz', '-bbinary', '-mmips', '-EB']
    flags1 = [f"--start-address={start_addr + base_shift}", f"--stop-address={end_addr + base_shift}"]
    flags2 = [f"--start-address={start_addr}", f"--stop-address={end_addr}"]
    return (
        myimg,
        (objdump_flags + flags1, baseimg, None),
        (objdump_flags + flags2, myimg, None),
    )


# Alignment with ANSI colors is broken, let's fix it.
def ansi_ljust(s, width):
    """Left-justify *s* to *width* visible columns, ignoring ANSI escapes."""
    needed = width - ansiwrap.ansilen(s)
    if needed > 0:
        return s + ' ' * needed
    else:
        return s


re_int = re.compile(r'[0-9]+')
re_comments = re.compile(r'<.*?>')
re_regs = re.compile(r'\b(a[0-3]|t[0-9]|s[0-7]|at|v[01]|f[12]?[0-9]|f3[01]|fp)\b')
re_sprel = re.compile(r',([1-9][0-9]*|0x[1-9a-f][0-9a-f]*)\(sp\)')
re_large_imm = re.compile(r'-?[1-9][0-9]{2,}|-?0x[0-9a-f]{3,}')
forbidden = set(string.ascii_letters + '_')

branch_likely_instructions = set([
    'beql', 'bnel', 'beqzl', 'bnezl', 'bgezl', 'bgtzl', 'blezl', 'bltzl',
    'bc1tl', 'bc1fl'
])
branch_instructions = set([
    'b', 'beq', 'bne', 'beqz', 'bnez', 'bgez', 'bgtz', 'blez', 'bltz',
    'bc1t', 'bc1f'
] + list(branch_likely_instructions))


def hexify_int(row, pat):
    """re.sub callback: rewrite the decimal integer matched by *pat* as hex.

    Single digits are left alone, as are digit runs glued to a letter or
    underscore in *row* (i.e. part of an identifier).
    """
    full = pat.group(0)
    if len(full) <= 1:
        # leave one-digit ints alone
        return full
    start, end = pat.span()
    if start and row[start - 1] in forbidden:
        return full
    if end < len(row) and row[end] in forbidden:
        return full
    return hex(int(full))


def parse_relocated_line(line):
    """Split an instruction line into ``(before, imm, after)`` around the
    operand a relocation applies to.

    The operand is taken after the last ',' (or the last tab if there is no
    comma), up to an optional '('. An immediate of ``0x0`` is normalized to ``0``.
    """
    try:
        ind2 = line.rindex(',')
    except ValueError:
        ind2 = line.rindex('\t')
    before = line[:ind2+1]
    after = line[ind2+1:]
    ind2 = after.find('(')
    if ind2 == -1:
        imm, after = after, ''
    else:
        imm, after = after[:ind2], after[ind2:]
    if imm == '0x0':
        imm = '0'
    return before, imm, after


def process_reloc(row, prev):
    """Rewrite instruction line *prev* using the relocation line *row*.

    The relocated operand becomes ``sym[+addend]``, wrapped in ``%lo()`` or
    ``%hi()`` for LO16/HI16 relocations; R_MIPS_26 is used as-is.
    """
    before, imm, after = parse_relocated_line(prev)
    repl = row.split()[-1]
    if imm != '0':
        if before.strip() == 'jal' and not imm.startswith('0x'):
            imm = '0x' + imm
        repl += '+' + imm if int(imm, 0) > 0 else imm
    if 'R_MIPS_LO16' in row:
        repl = f'%lo({repl})'
    elif 'R_MIPS_HI16' in row:
        # Ideally we'd pair up R_MIPS_LO16 and R_MIPS_HI16 to generate a
        # correct addend for each, but objdump doesn't give us the order of
        # the relocations, so we can't find the right LO16. :(
        repl = f'%hi({repl})'
    else:
        assert 'R_MIPS_26' in row, f"unknown relocation type '{row}'"
    return before + repl + after


def process(lines):
    """Turn objdump output lines into parallel per-instruction lists.

    Returns ``(mnemonics, diff_rows, originals, line_nums, branch_targets)``:
    - mnemonics: the sequence aligned by the differ ('' for a skipped
      branch-likely delay slot),
    - diff_rows: normalized text (registers removed, sp offsets generalized,
      optionally large immediates removed) used to detect "same but different",
    - originals: display text, tab-separated fields padded to 8 columns,
    - line_nums: the address column, e.g. ``'4c:'``,
    - branch_targets: branch target address (hex, no prefix) or None.
    """
    mnemonics = []
    diff_rows = []
    skip_next = False
    originals = []
    line_nums = []
    branch_targets = []
    if not args.diff_obj:
        # Skip objdump's header lines in raw binary mode.
        lines = lines[7:]
        if lines and not lines[-1]:
            lines.pop()

    for row in lines:
        if args.diff_obj and ('>:' in row or not row):
            continue

        if 'R_MIPS_' in row:
            # Relocation lines annotate the instruction just before them.
            if diff_rows[-1] != '':
                diff_rows[-1] = process_reloc(row, diff_rows[-1])
            originals[-1] = process_reloc(row, originals[-1])
            continue

        row = re.sub(re_comments, '', row)
        row = row.rstrip()
        tabs = row.split('\t')
        row = '\t'.join(tabs[2:])
        line_num = tabs[0].strip()
        row_parts = row.split('\t', 1)
        mnemonic = row_parts[0].strip()
        if mnemonic not in branch_instructions:
            row = re.sub(re_int, lambda s: hexify_int(row, s), row)
        original = row
        if skip_next:
            # Delay slot of a branch-likely instruction: blank it for diffing.
            skip_next = False
            row = ''
            mnemonic = ''
        if mnemonic in branch_likely_instructions:
            skip_next = True
        row = re.sub(re_regs, '', row)
        row = re.sub(re_sprel, ',addr(sp)', row)
        if args.ignore_large_imms:
            row = re.sub(re_large_imm, '', row)

        mnemonics.append(mnemonic)
        diff_rows.append(row)
        originals.append(original)
        line_nums.append(line_num)
        if mnemonic in branch_instructions:
            target = row_parts[1].strip().split(',')[-1]
            if mnemonic in branch_likely_instructions:
                target = hex(int(target, 16) - 4)[2:]
            branch_targets.append(target)
        else:
            branch_targets.append(None)
        if args.stop_jrra and mnemonic == 'jr' and row_parts[1].strip() == 'ra':
            break

    # Cleanup whitespace: replace tab separators by padding each field to 8 columns.
    originals = [original.strip() for original in originals]
    originals = [''.join(f'{o:<8s}' for o in original.split('\t')) for original in originals]
    return mnemonics, diff_rows, originals, line_nums, branch_targets


def format_single_line_diff(line1, line2, column_width):
    """Join two already-colored cells into one row of two fixed-width columns."""
    return f"{ansi_ljust(line1,column_width)}{ansi_ljust(line2,column_width)}"


class SymbolColorer:
    """Assign each distinct symbol a color from COLOR_ROTATION and remember it.

    Colors are handed out in rotation starting at *base_index*.
    """

    def __init__(self, base_index):
        self.color_index = base_index
        self.symbol_colors = {}

    def color_symbol(self, s, t=None):
        """Return *t* (default: *s*) wrapped in the color assigned to *s*."""
        try:
            color = self.symbol_colors[s]
        except KeyError:
            color = COLOR_ROTATION[self.color_index % len(COLOR_ROTATION)]
            self.color_index += 1
            self.symbol_colors[s] = color
        t = t or s
        return f'{color}{t}{Fore.RESET}'


def normalize_large_imms(row):
    """Strip large immediates from *row* when --ignore-large-imms is set."""
    if args.ignore_large_imms:
        row = re.sub(re_large_imm, '', row)
    return row


def do_diff(basedump, mydump):
    """Diff two objdump texts and return the side-by-side output rows.

    Rows are aligned on mnemonics. Identical rows are plain; register/stack
    offset-only differences are yellow ('r'); other changes are blue ('|'),
    red ('<', only in base) or green ('>', only in ours). The first
    --skip-lines rows are dropped.
    """
    asm_lines1 = basedump.split('\n')
    asm_lines2 = mydump.split('\n')

    output = []

    # TODO: status line?
    # output.append(sha1sum(mydump))

    mnemonics1, asm_lines1, originals1, line_nums1, branch_targets1 = process(asm_lines1)
    mnemonics2, asm_lines2, originals2, line_nums2, branch_targets2 = process(asm_lines2)

    sc1 = SymbolColorer(0)
    sc2 = SymbolColorer(0)
    sc3 = SymbolColorer(4)
    sc4 = SymbolColorer(4)
    sc5 = SymbolColorer(0)
    sc6 = SymbolColorer(0)
    bts1 = set()
    bts2 = set()

    if args.show_branches:
        # Pre-assign a color to every branch target so that both the source
        # and destination arrows of a branch share it.
        for (bts, btset, sc) in [(branch_targets1, bts1, sc5), (branch_targets2, bts2, sc6)]:
            for bt in bts:
                if bt is not None:
                    btset.add(bt + ":")
                    sc.color_symbol(bt + ":")

    differ: difflib.SequenceMatcher = difflib.SequenceMatcher(a=mnemonics1, b=mnemonics2, autojunk=False)
    for (tag, i1, i2, j1, j2) in differ.get_opcodes():
        lines1 = asm_lines1[i1:i2]
        lines2 = asm_lines2[j1:j2]

        for k, (line1, line2) in enumerate(itertools.zip_longest(lines1, lines2)):
            if tag == 'replace':
                # An uneven replace block degrades into inserts/deletes once
                # the shorter side runs out.
                if line1 is None:
                    tag = 'insert'
                elif line2 is None:
                    tag = 'delete'

            try:
                original1 = originals1[i1+k]
                line_num1 = line_nums1[i1+k]
            except IndexError:
                original1 = ''
                line_num1 = ''
            try:
                original2 = originals2[j1+k]
                line_num2 = line_nums2[j1+k]
            except IndexError:
                original2 = ''
                line_num2 = ''

            line_color = Fore.RESET
            line_prefix = ' '
            if line1 == line2:
                if normalize_large_imms(original1) == normalize_large_imms(original2):
                    out1 = f'{original1}'
                    out2 = f'{original2}'
                elif line1 == '':
                    out1 = f'{Style.DIM}{original1}'
                    out2 = f'{Style.DIM}{original2}'
                else:
                    line_color = Fore.YELLOW
                    line_prefix = 'r'
                    out1 = f'{Fore.YELLOW}{original1}{Style.RESET_ALL}'
                    out2 = f'{Fore.YELLOW}{original2}{Style.RESET_ALL}'
                    out1 = re.sub(re_regs, lambda s: sc1.color_symbol(s.group()), out1)
                    out2 = re.sub(re_regs, lambda s: sc2.color_symbol(s.group()), out2)
                    out1 = re.sub(re_sprel, lambda s: sc3.color_symbol(s.group()), out1)
                    out2 = re.sub(re_sprel, lambda s: sc4.color_symbol(s.group()), out2)
            elif tag in ['replace', 'equal']:
                line_prefix = '|'
                line_color = Fore.BLUE
                out1 = f"{Fore.BLUE}{original1}{Style.RESET_ALL}"
                out2 = f"{Fore.BLUE}{original2}{Style.RESET_ALL}"
            elif tag == 'delete':
                line_prefix = '<'
                line_color = Fore.RED
                out1 = f"{Fore.RED}{original1}{Style.RESET_ALL}"
                out2 = ''
            elif tag == 'insert':
                line_prefix = '>'
                line_color = Fore.GREEN
                out1 = ''
                out2 = f"{Fore.GREEN}{original2}{Style.RESET_ALL}"

            # Placeholder as wide as the '~>' arrow so rows stay aligned.
            in_arrow1 = '  '
            in_arrow2 = '  '
            out_arrow1 = ''
            out_arrow2 = ''
            line_num1 = line_num1 if out1 else ''
            line_num2 = line_num2 if out2 else ''

            if args.show_branches and out1:
                if line_num1 in bts1:
                    in_arrow1 = sc5.color_symbol(line_num1, '~>')
                if branch_targets1[i1+k] is not None:
                    out_arrow1 = ' ' + sc5.color_symbol(branch_targets1[i1+k] + ":", '~>')
            if args.show_branches and out2:
                if line_num2 in bts2:
                    in_arrow2 = sc6.color_symbol(line_num2, '~>')
                if branch_targets2[j1+k] is not None:
                    out_arrow2 = ' ' + sc6.color_symbol(branch_targets2[j1+k] + ":", '~>')

            out1 = f"{line_color}{line_num1} {in_arrow1} {out1}{Style.RESET_ALL}{out_arrow1}"
            out2 = f"{line_color}{line_prefix} {line_num2} {in_arrow2} {out2}{Style.RESET_ALL}{out_arrow2}"
            output.append(format_single_line_diff(out1, out2, args.column_width))

    return output[args.skip_lines:]


def debounced_fs_watch(targets, outq, debounce_delay):
    """Watch *targets* and report debounced change events on *outq*.

    Directories are watched recursively; a plain file is watched through its
    parent directory, with events filtered to that file. With --make, any
    path ending in one of FS_WATCH_EXTENSIONS also counts. Once no new event
    has arrived for *debounce_delay* seconds, the timestamp of the latest
    event is put on *outq*. Runs in a daemon thread; returns immediately.
    """
    import watchdog.events
    import watchdog.observers

    class WatchEventHandler(watchdog.events.FileSystemEventHandler):
        def __init__(self, queue, file_targets):
            self.queue = queue
            self.file_targets = file_targets

        def on_modified(self, ev):
            if isinstance(ev, watchdog.events.FileModifiedEvent):
                self.changed(ev.src_path)

        def on_moved(self, ev):
            if isinstance(ev, watchdog.events.FileMovedEvent):
                self.changed(ev.dest_path)

        def should_notify(self, path):
            # Targets are stored normalized; watchdog may report e.g. "./x".
            path = os.path.normpath(path)
            if any(path == target for target in self.file_targets):
                return True
            if args.make and any(path.endswith(suffix) for suffix in FS_WATCH_EXTENSIONS):
                return True
            return False

        def changed(self, path):
            if self.should_notify(path):
                self.queue.put(time.time())

    def debounce_thread():
        listenq = queue.Queue()
        file_targets = []
        event_handler = WatchEventHandler(listenq, file_targets)
        observer = watchdog.observers.Observer()
        observed = set()
        for target in targets:
            if os.path.isdir(target):
                observer.schedule(event_handler, target, recursive=True)
            else:
                target = os.path.normpath(target)
                file_targets.append(target)
                # A bare filename has no dirname; watch the current directory.
                target_dir = os.path.dirname(target) or '.'
                if target_dir not in observed:
                    observed.add(target_dir)
                    observer.schedule(event_handler, target_dir)
        observer.start()
        while True:
            t = listenq.get()
            more = True
            while more:
                delay = t + debounce_delay - time.time()
                if delay > 0:
                    time.sleep(delay)
                # consume entire queue
                more = False
                try:
                    while True:
                        t = listenq.get(block=False)
                        more = True
                except queue.Empty:
                    pass
            outq.put(t)

    th = threading.Thread(target=debounce_thread, daemon=True)
    th.start()


class Display:
    """Show the diff (or an error message) in ``less``.

    In watch mode, a background thread restarts ``less`` with fresh content
    whenever update() kills the running instance.
    """

    def __init__(self, basedump, mydump):
        self.basedump = basedump
        self.mydump = mydump
        self.emsg = None
        self.less_proc = None

    def run_less(self):
        """Start ``tail | less`` on the current content; return both Popens."""
        if self.emsg is not None:
            output = self.emsg
        else:
            output = '\n'.join(do_diff(self.basedump, self.mydump))

        # Pipe the output through 'tail' and only then to less, to ensure the
        # write call doesn't block. ('tail' has to buffer all its input before
        # it starts writing.) This also means we don't have to deal with pipe
        # closure errors.
        buffer_proc = subprocess.Popen(BUFFER_CMD, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
        less_proc = subprocess.Popen(LESS_CMD, stdin=buffer_proc.stdout)
        buffer_proc.stdin.write(output.encode())
        buffer_proc.stdin.close()
        buffer_proc.stdout.close()
        return (buffer_proc, less_proc)

    def run_sync(self):
        """Show the content once and block until the user quits ``less``."""
        proca, procb = self.run_less()
        procb.wait()
        proca.wait()

    def run_async(self, watch_queue):
        """Start the display thread; return once ``less`` is running.

        When the user quits, None is put on *watch_queue*.
        """
        self.watch_queue = watch_queue
        self.ready_queue = queue.Queue()
        self.pending_update = None
        dthread = threading.Thread(target=self.display_thread)
        dthread.start()
        self.ready_queue.get()

    def display_thread(self):
        proca, procb = self.run_less()
        self.less_proc = procb
        self.ready_queue.put(0)
        while True:
            ret = procb.wait()
            proca.wait()
            self.less_proc = None
            if ret != 0:
                # fix the terminal
                os.system("tput reset")
            if ret != 0 and self.pending_update is not None:
                # killed by program with the intent to refresh
                msg, error = self.pending_update
                self.pending_update = None
                if not error:
                    self.mydump = msg
                    self.emsg = None
                else:
                    self.emsg = msg
                proca, procb = self.run_less()
                self.less_proc = procb
                self.ready_queue.put(0)
            else:
                # terminated by user, or killed
                self.watch_queue.put(None)
                self.ready_queue.put(0)
                break

    def progress(self, msg):
        # Write message to top-left corner
        sys.stdout.write("\x1b7\x1b[1;1f{}\x1b8".format(msg + " "))
        sys.stdout.flush()

    def update(self, text, error):
        """Replace the shown content with a new dump (or an error message if
        *error*), restarting ``less``; report "Unchanged." if nothing changed."""
        if not error and not self.emsg and text == self.mydump:
            self.progress("Unchanged. ")
            return
        self.pending_update = (text, error)
        if not self.less_proc:
            return
        self.less_proc.kill()
        self.ready_queue.get()

    def terminate(self):
        """Kill the running ``less`` (if any) and wait for the thread to react."""
        if not self.less_proc:
            return
        self.less_proc.kill()
        self.ready_queue.get()


def main():
    if args.diff_obj:
        make_target, basecmd, mycmd = dump_objfile()
    else:
        make_target, basecmd, mycmd = dump_binary()

    if args.write_asm is not None:
        mydump = run_objdump(mycmd)
        # Must be opened for writing; the default mode is read-only.
        with open(args.write_asm, 'w') as f:
            f.write(mydump)
        print(f"Wrote assembly to {args.write_asm}.")
        sys.exit(0)

    if args.base_asm is not None:
        with open(args.base_asm) as f:
            basedump = f.read()
    else:
        basedump = run_objdump(basecmd)

    mydump = run_objdump(mycmd)

    display = Display(basedump, mydump)

    if not args.watch:
        display.run_sync()
    else:
        if not args.make:
            yn = input("Warning: watch-mode (-w) enabled without auto-make (-m). You will have to run make manually. Ok? (Y/n) ")
            if yn.lower() == 'n':
                return
        if args.make:
            watch_sources = None
            if hasattr(diff_settings, "watch_sources_for_target"):
                watch_sources = diff_settings.watch_sources_for_target(make_target)
            watch_sources = watch_sources or source_directories
            if not watch_sources:
                fail("Missing source_directories config, don't know what to watch.")
        else:
            watch_sources = [make_target]
        q = queue.Queue()
        debounced_fs_watch(watch_sources, q, DEBOUNCE_DELAY)
        display.run_async(q)
        last_build = 0
        try:
            while True:
                t = q.get()
                if t is None:
                    break
                if t < last_build:
                    # Change happened before the last build started; already covered.
                    continue
                last_build = time.time()
                if args.make:
                    display.progress("Building...")
                    ret = run_make(make_target, capture_output=True)
                    if ret.returncode != 0:
                        display.update(ret.stderr.decode() or ret.stdout.decode(), error=True)
                        continue
                mydump = run_objdump(mycmd)
                display.update(mydump, error=False)
        except KeyboardInterrupt:
            display.terminate()


main()