# Mirror of https://git.mirrors.martin98.com/https://github.com/Ultimaker/Cura
# (synced 2025-04-22 13:49:39 +08:00; 221 lines, 8.0 KiB, Python)
"""
STK500v2 protocol implementation for programming AVR chips.
The STK500v2 protocol is used by the ArduinoMega2560 and a few other Arduino platforms to load firmware.
This is a Python 3 conversion of the code created by David Braam for the Cura project.
"""
|
|
import os
|
|
import struct
|
|
import sys
|
|
import time
|
|
|
|
from serial import Serial
|
|
from serial import SerialException
|
|
from serial import SerialTimeoutException
|
|
|
|
from . import ispBase, intelHex
|
|
|
|
|
|
class Stk500v2(ispBase.IspBase):
    """Programmer that speaks the STK500v2 serial protocol to an AVR bootloader.

    The object owns a pyserial ``Serial`` port (``self.serial``) while
    connected.  ``writeFlash`` and ``verifyFlash`` read the chip description
    from ``self.chip`` ("pageSize" and "pageCount" keys), which this class
    never assigns itself -- presumably the ``ispBase.IspBase`` base class sets
    it before calling them; verify against ispBase.

    Set ``progress_callback`` to a callable taking ``(current, total)`` to be
    notified after every page written or verified.
    """

    def __init__(self):
        self.serial = None
        self.seq = 1
        self.last_addr = -1
        self.progress_callback = None
        # Defined up front so hasChecksumFunction() is safe to call before
        # connect() has probed the bootloader.
        self._has_checksum = False

    def connect(self, port="COM22", speed=115200):
        """Open the serial port, reset the controller and enter programming mode.

        Any connection that is already open is closed first.

        :param port: Name of the serial port; converted with ``str()``.
        :param speed: Baud rate.
        :raises ispBase.IspError: If the port cannot be opened or the device
            does not acknowledge the request to enter programming mode.  The
            port is closed again before the error propagates.
        """
        if self.serial is not None:
            self.close()
        try:
            self.serial = Serial(str(port), speed, timeout=1, writeTimeout=10000)
        except SerialException as e:
            raise ispBase.IspError("Failed to open serial port") from e
        except Exception as e:
            # str(port): the port may be given as a non-string object, which
            # would make the message concatenation itself fail.
            raise ispBase.IspError(
                "Unexpected error while connecting to serial port:" + str(port) + ":" + str(type(e))
            ) from e
        self.seq = 1

        # Reset the controller by toggling DTR twice.
        for _ in range(2):
            self.serial.setDTR(True)
            time.sleep(0.1)
            self.serial.setDTR(False)
            time.sleep(0.1)
        time.sleep(0.2)

        self.serial.flushInput()
        self.serial.flushOutput()
        try:
            # 0x10 requests programming mode; the expected reply echoes the
            # command followed by a 0x00 status byte.
            enter_reply = self.sendMessage(
                [0x10, 0xc8, 0x64, 0x19, 0x20, 0x00, 0x53, 0x03, 0xac, 0x53, 0x00, 0x00]
            )
            if enter_reply != [0x10, 0x00]:
                raise ispBase.IspError("Failed to enter programming mode")

            self.sendMessage([0x06, 0x80, 0x00, 0x00, 0x00])
            # Probe the 0xEE command: a 0x00 status byte means the bootloader
            # can report a flash checksum (used by verifyFlash).
            probe_reply = self.sendMessage([0xEE])
            self._has_checksum = len(probe_reply) > 1 and probe_reply[1] == 0x00
        except ispBase.IspError:
            # Do not leave the port open (and locked) after a failed handshake.
            self.close()
            raise
        self.serial.timeout = 5

    def close(self):
        """Close the serial port if it is open; a no-op otherwise."""
        if self.serial is not None:
            self.serial.close()
            self.serial = None

    # Leave ISP does not reset the serial port, only resets the device, and returns the serial port
    # after disconnecting it from the programming interface.
    # This allows you to use the serial port without opening it again.
    def leaveISP(self):
        """Leave programming mode and hand the still-open serial port to the caller.

        :return: The open ``Serial`` object, or ``None`` when not connected.
            After a successful call this object no longer references the port.
        :raises ispBase.IspError: If the device does not acknowledge the request.
        """
        if self.serial is None:
            return None
        if self.sendMessage([0x11]) != [0x11, 0x00]:
            raise ispBase.IspError("Failed to leave programming mode")
        port = self.serial
        self.serial = None
        return port

    def isConnected(self):
        """Return True while a serial port is held by this object."""
        return self.serial is not None

    def hasChecksumFunction(self):
        """Return True if the last connect() found the 0xEE checksum command."""
        return self._has_checksum

    def sendISP(self, data):
        """Send four raw ISP bytes through command 0x1D and return the reply.

        :param data: Sequence whose first four items are the bytes to send.
        :return: Items 2 to 5 of the response message (at most four values).
        """
        recv = self.sendMessage([0x1D, 4, 4, 0, data[0], data[1], data[2], data[3]])
        return recv[2:6]

    def writeFlash(self, flash_data):
        """Write ``flash_data`` to flash, one page per message.

        :param flash_data: Sequence of byte values (0-255) to program.

        The progress callback is called with ``(pages_done, total)``.  When
        the bootloader has no checksum command, ``total`` is twice the page
        count, because verifyFlash() then reports the second half.
        """
        # "pageSize" is doubled to get the page size in bytes -- presumably it
        # is stored in 16-bit words; verify against the chip table.
        page_size = self.chip["pageSize"] * 2
        flash_size = page_size * self.chip["pageCount"]
        print("Writing flash")
        # Set load addr to 0, in case we have more than 64k flash we need to
        # enable the address extension.
        if flash_size > 0xFFFF:
            self.sendMessage([0x06, 0x80, 0x00, 0x00, 0x00])
        else:
            self.sendMessage([0x06, 0x00, 0x00, 0x00, 0x00])

        # Integer ceiling division: "/" would give a fractional float total.
        load_count = (len(flash_data) + page_size - 1) // page_size
        progress_total = load_count if self._has_checksum else load_count * 2
        header = [0x13, page_size >> 8, page_size & 0xFF, 0xc1, 0x0a, 0x40, 0x4c, 0x20, 0x00, 0x00]
        for i in range(load_count):
            start = i * page_size
            self.sendMessage(header + list(flash_data[start:start + page_size]))
            if self.progress_callback is not None:
                self.progress_callback(i + 1, progress_total)

    def verifyFlash(self, flash_data):
        """Check that the flash contents match ``flash_data``.

        With a checksum-capable bootloader, the 16-bit sum of all bytes is
        compared with the value reported by the device.  Otherwise the flash
        is read back in 256-byte chunks and compared byte by byte.

        :param flash_data: Sequence of byte values (0-255) that was written.
        :raises ispBase.IspError: On checksum mismatch, on the first differing
            byte, or when the device returns fewer bytes than expected.
        """
        if self._has_checksum:
            # Load len(flash_data) >> 1 as the address before asking for the
            # checksum -- presumably the image length in words.
            self.sendMessage([
                0x06,
                0x00,
                (len(flash_data) >> 17) & 0xFF,
                (len(flash_data) >> 9) & 0xFF,
                (len(flash_data) >> 1) & 0xFF,
            ])
            res = self.sendMessage([0xEE])
            checksum_recv = res[2] | (res[3] << 8)
            checksum = sum(flash_data) & 0xFFFF
            if checksum != checksum_recv:
                raise ispBase.IspError(
                    "Verify checksum mismatch: 0x%x != 0x%x" % (checksum, checksum_recv)
                )
            return

        # Set load addr to 0, in case we have more than 64k flash we need to
        # enable the address extension.
        flash_size = self.chip["pageSize"] * 2 * self.chip["pageCount"]
        if flash_size > 0xFFFF:
            self.sendMessage([0x06, 0x80, 0x00, 0x00, 0x00])
        else:
            self.sendMessage([0x06, 0x00, 0x00, 0x00, 0x00])

        load_count = (len(flash_data) + 0xFF) // 0x100
        for i in range(load_count):
            recv = self.sendMessage([0x14, 0x01, 0x00, 0x20])[2:0x102]
            if self.progress_callback is not None:
                # Verification is reported as the second half of the progress.
                self.progress_callback(load_count + i + 1, load_count * 2)
            offset = i * 0x100
            for j, expected in enumerate(flash_data[offset:offset + 0x100]):
                # A short read counts as a mismatch at the first missing byte.
                if j >= len(recv) or expected != recv[j]:
                    raise ispBase.IspError("Verify error at: 0x%x" % (offset + j))

    def sendMessage(self, data):
        """Frame ``data`` as one message, send it and return the reply body.

        The frame is: 0x1B, sequence number, 16-bit big-endian length, 0x0E,
        the data bytes, then the XOR of all preceding bytes.

        :param data: Sequence of byte values (0-255) forming the message body.
        :return: The reply body as a list of ints (see recvMessage).
        :raises ispBase.IspError: If the write times out or no reply arrives.
        """
        message = struct.pack(">BBHB", 0x1B, self.seq, len(data), 0x0E)
        message += struct.pack(">%dB" % len(data), *data)
        checksum = 0
        for value in bytearray(message):
            checksum ^= value
        message += struct.pack(">B", checksum)
        try:
            self.serial.write(message)
            self.serial.flush()
        except SerialTimeoutException as e:
            raise ispBase.IspError("Serial send timeout") from e
        self.seq = (self.seq + 1) & 0xFF
        return self.recvMessage()

    def recvMessage(self):
        """Read one framed message from the serial port and return its body.

        Bytes are consumed one at a time by a small state machine.  Input that
        does not form a valid frame (wrong token or checksum) is discarded and
        the search for the 0x1B start byte resumes.  The sequence number of
        the reply is not checked.

        :return: The message body as a list of ints.
        :raises ispBase.IspError: If a read returns no data (timeout).
        """
        state = "Start"
        checksum = 0
        msg_size = 0
        data = []
        while True:
            s = self.serial.read()
            if len(s) < 1:
                raise ispBase.IspError("Timeout")
            b = struct.unpack(">B", s)[0]
            # XOR over the whole frame, including the checksum byte, is 0 for
            # a valid message.
            checksum ^= b
            if state == "Start":
                if b == 0x1B:
                    state = "GetSeq"
                    checksum = 0x1B
            elif state == "GetSeq":
                state = "MsgSize1"
            elif state == "MsgSize1":
                msg_size = b << 8
                state = "MsgSize2"
            elif state == "MsgSize2":
                msg_size |= b
                state = "Token"
            elif state == "Token":
                if b != 0x0E:
                    state = "Start"
                else:
                    data = []
                    # An empty body is followed directly by the checksum byte;
                    # going to "Data" would swallow it and never terminate.
                    state = "Data" if msg_size > 0 else "Checksum"
            elif state == "Data":
                data.append(b)
                if len(data) == msg_size:
                    state = "Checksum"
            elif state == "Checksum":
                if checksum != 0:
                    state = "Start"
                else:
                    return data
|
|
|
|
def portList():
    """Return the USB serial ports listed in the Windows registry.

    Reads ``HKEY_LOCAL_MACHINE\\HARDWARE\\DEVICEMAP\\SERIALCOMM`` and collects
    the data of every value whose name contains "USBSER".

    :return: List of port names.  The list is empty when the registry module
        is unavailable (non-Windows platforms) or the key does not exist.
    """
    # The module is called "winreg" on Python 3 and "_winreg" on Python 2.
    try:
        import winreg
    except ImportError:
        try:
            import _winreg as winreg
        except ImportError:
            return []

    ports = []
    try:
        key = winreg.OpenKey(winreg.HKEY_LOCAL_MACHINE, "HARDWARE\\DEVICEMAP\\SERIALCOMM")
    except OSError:
        return ports
    try:
        index = 0
        while True:
            try:
                name, value = winreg.EnumValue(key, index)[:2]
            except OSError:
                # EnumValue raises once the index runs past the last value.
                break
            if "USBSER" in name:
                ports.append(value)
            index += 1
    finally:
        winreg.CloseKey(key)
    return ports
|
|
|
|
def runProgrammer(port, filename):
    """ Run an STK500v2 program on serial port 'port' and write 'filename' into flash.

    :param port: Name of the serial port to connect to.
    :param filename: Path of the Intel HEX file to program.

    The connection is closed again even when programming fails, so the serial
    port is not left open.
    """
    programmer = Stk500v2()
    programmer.connect(port=port)
    try:
        programmer.programChip(intelHex.readHex(filename))
    finally:
        programmer.close()
|
|
|
|
def main():
    """ Entry point to call the stk500v2 programmer from the commandline.

    Usage: ``<script> <port|AUTO> <firmware.hex>``

    With ``AUTO`` as the port, every port returned by portList() is programmed
    in its own thread, with thread starts spaced five seconds apart.
    Otherwise the single named port is programmed.

    Exits with status 2 when arguments are missing.  A successful run returns
    normally (exit status 0).
    """
    import threading

    if len(sys.argv) < 3:
        script = sys.argv[0] if sys.argv else "stk500v2"
        sys.stderr.write("Usage: %s <port|AUTO> <firmware.hex>\n" % script)
        sys.exit(2)

    target, hex_file = sys.argv[1], sys.argv[2]
    if target == "AUTO":
        # Query the ports once so the printed list is the one programmed.
        ports = portList()
        print(ports)
        for port in ports:
            threading.Thread(target=runProgrammer, args=(port, hex_file)).start()
            time.sleep(5)
    else:
        # runProgrammer connects, programs and closes the port again.
        runProgrammer(target, hex_file)


if __name__ == "__main__":
    main()
|