""" STK500v2 protocol implementation for programming AVR chips. The STK500v2 protocol is used by the ArduinoMega2560 and a few other Arduino platforms to load firmware. This is a python 3 conversion of the code created by David Braam for the Cura project. """ import os import struct import sys import time from serial import Serial from serial import SerialException from serial import SerialTimeoutException from . import ispBase, intelHex class Stk500v2(ispBase.IspBase): def __init__(self): self.serial = None self.seq = 1 self.last_addr = -1 self.progressCallback = None def connect(self, port = "COM22", speed = 115200): if self.serial is not None: self.close() try: self.serial = Serial(str(port), speed, timeout=1, writeTimeout=10000) except SerialException as e: raise ispBase.IspError("Failed to open serial port") except: raise ispBase.IspError("Unexpected error while connecting to serial port:" + port + ":" + str(sys.exc_info()[0])) self.seq = 1 #Reset the controller for n in range(0, 2): self.serial.setDTR(True) time.sleep(0.1) self.serial.setDTR(False) time.sleep(0.1) time.sleep(0.2) self.serial.flushInput() self.serial.flushOutput() if self.sendMessage([0x10, 0xc8, 0x64, 0x19, 0x20, 0x00, 0x53, 0x03, 0xac, 0x53, 0x00, 0x00]) != [0x10, 0x00]: self.close() raise ispBase.IspError("Failed to enter programming mode") self.sendMessage([0x06, 0x80, 0x00, 0x00, 0x00]) if self.sendMessage([0xEE])[1] == 0x00: self._has_checksum = True else: self._has_checksum = False self.serial.timeout = 5 def close(self): if self.serial is not None: self.serial.close() self.serial = None #Leave ISP does not reset the serial port, only resets the device, and returns the serial port after disconnecting it from the programming interface. # This allows you to use the serial port without opening it again. def leaveISP(self): if self.serial is not None: if self.sendMessage([0x11]) != [0x11, 0x00]: raise ispBase.IspError("Failed to leave programming mode") ret = self.serial self.serial = None return ret return None def isConnected(self): return self.serial is not None def hasChecksumFunction(self): return self._has_checksum def sendISP(self, data): recv = self.sendMessage([0x1D, 4, 4, 0, data[0], data[1], data[2], data[3]]) return recv[2:6] def writeFlash(self, flash_data): #Set load addr to 0, in case we have more then 64k flash we need to enable the address extension page_size = self.chip["pageSize"] * 2 flash_size = page_size * self.chip["pageCount"] print("Writing flash") if flash_size > 0xFFFF: self.sendMessage([0x06, 0x80, 0x00, 0x00, 0x00]) else: self.sendMessage([0x06, 0x00, 0x00, 0x00, 0x00]) load_count = (len(flash_data) + page_size - 1) / page_size for i in range(0, int(load_count)): recv = self.sendMessage([0x13, page_size >> 8, page_size & 0xFF, 0xc1, 0x0a, 0x40, 0x4c, 0x20, 0x00, 0x00] + flash_data[(i * page_size):(i * page_size + page_size)]) if self.progressCallback is not None: if self._has_checksum: self.progressCallback(i + 1, load_count) else: self.progressCallback(i + 1, load_count * 2) def verifyFlash(self, flash_data): if self._has_checksum: self.sendMessage([0x06, 0x00, (len(flash_data) >> 17) & 0xFF, (len(flash_data) >> 9) & 0xFF, (len(flash_data) >> 1) & 0xFF]) res = self.sendMessage([0xEE]) checksum_recv = res[2] | (res[3] << 8) checksum = 0 for d in flash_data: checksum += d checksum &= 0xFFFF if hex(checksum) != hex(checksum_recv): raise ispBase.IspError("Verify checksum mismatch: 0x%x != 0x%x" % (checksum & 0xFFFF, checksum_recv)) else: #Set load addr to 0, in case we have more then 64k flash we need to enable the address extension flash_size = self.chip["pageSize"] * 2 * self.chip["pageCount"] if flash_size > 0xFFFF: self.sendMessage([0x06, 0x80, 0x00, 0x00, 0x00]) else: self.sendMessage([0x06, 0x00, 0x00, 0x00, 0x00]) load_count = (len(flash_data) + 0xFF) / 0x100 for i in range(0, int(load_count)): recv = self.sendMessage([0x14, 0x01, 0x00, 0x20])[2:0x102] if self.progressCallback is not None: self.progressCallback(load_count + i + 1, load_count * 2) for j in range(0, 0x100): if i * 0x100 + j < len(flash_data) and flash_data[i * 0x100 + j] != recv[j]: raise ispBase.IspError("Verify error at: 0x%x" % (i * 0x100 + j)) def sendMessage(self, data): message = struct.pack(">BBHB", 0x1B, self.seq, len(data), 0x0E) for c in data: message += struct.pack(">B", c) checksum = 0 for c in message: checksum ^= c message += struct.pack(">B", checksum) try: self.serial.write(message) self.serial.flush() except SerialTimeoutException: raise ispBase.IspError("Serial send timeout") self.seq = (self.seq + 1) & 0xFF return self.recvMessage() def recvMessage(self): state = "Start" checksum = 0 while True: s = self.serial.read() if len(s) < 1: raise ispBase.IspError("Timeout") b = struct.unpack(">B", s)[0] checksum ^= b #print(hex(b)) if state == "Start": if b == 0x1B: state = "GetSeq" checksum = 0x1B elif state == "GetSeq": state = "MsgSize1" elif state == "MsgSize1": msg_size = b << 8 state = "MsgSize2" elif state == "MsgSize2": msg_size |= b state = "Token" elif state == "Token": if b != 0x0E: state = "Start" else: state = "Data" data = [] elif state == "Data": data.append(b) if len(data) == msg_size: state = "Checksum" elif state == "Checksum": if checksum != 0: state = "Start" else: return data def portList(): ret = [] import _winreg key=_winreg.OpenKey(_winreg.HKEY_LOCAL_MACHINE,"HARDWARE\\DEVICEMAP\\SERIALCOMM") i=0 while True: try: values = _winreg.EnumValue(key, i) except: return ret if "USBSER" in values[0]: ret.append(values[1]) i+=1 return ret def runProgrammer(port, filename): """ Run an STK500v2 program on serial port 'port' and write 'filename' into flash. """ programmer = Stk500v2() programmer.connect(port = port) programmer.programChip(intelHex.readHex(filename)) programmer.close() def main(): """ Entry point to call the stk500v2 programmer from the commandline. """ import threading if sys.argv[1] == "AUTO": print(portList()) for port in portList(): threading.Thread(target=runProgrammer, args=(port,sys.argv[2])).start() time.sleep(5) else: programmer = Stk500v2() programmer.connect(port = sys.argv[1]) programmer.programChip(intelHex.readHex(sys.argv[2])) sys.exit(1) if __name__ == "__main__": main()