"""Go-Back-N (GBN) sender simulation running over a TCP connection.

The "network layer" is a single accepted TCP connection. Packet loss is
simulated on the sending side (see ``Sender.udt_send``), and ACK numbers read
back from the peer are treated as cumulative (see ``Sender.rdt_rcv``).
"""
import random
import socket
import time
from typing import Optional


class Packet:
    """A payload tagged with its sequence number."""

    def __init__(self, data: str, seq_num: int) -> None:
        self.data = data
        self.seq_num = seq_num


class ApplicationLayer:
    """Source of the messages to be sent.

    Args:
        data_len: Number of messages to generate. Each message is ``"data"``
            followed by its 0-based index zero-padded to at least 4 digits.
    """

    def __init__(self, data_len: int = 5000) -> None:
        self.data_len = data_len
        self.data_to_send = ["data{:0>4d}".format(i) for i in range(data_len)]


class NetworkLayer:
    """Lower-layer channel: a TCP server that accepts exactly one peer.

    The constructor blocks in ``accept()`` until a peer connects. It then
    switches the connection to non-blocking mode so ``udt_rcv`` can poll.
    """

    def __init__(self, host: str, port: int) -> None:
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.socket.bind((host, port))
        self.socket.listen(1)
        print("等待下层的不可靠传输连接。")
        self.client_socket, _address = self.socket.accept()
        print("下层的不可靠传输连接成功。")
        self.client_socket.setblocking(False)

    def udt_send(self, data: str) -> None:
        """Send *data*, encoded as UTF-8, to the connected peer."""
        # sendall keeps writing until the whole buffer is out; a bare send()
        # may transmit only part of it.
        self.client_socket.sendall(data.encode())

    def udt_rcv(self) -> Optional[str]:
        """Poll for one ACK without blocking.

        Returns:
            Up to 4 bytes read from the connection, decoded as UTF-8.
            ``""`` if the peer closed the connection.
            ``None`` if nothing is available yet.
        """
        # NOTE(review): presumably the receiver sends each ACK as exactly 4
        # bytes (e.g. "001 "). TCP does not preserve message boundaries, so a
        # read may still return a partial or merged ACK. Confirm the framing
        # against the receiver.
        try:
            return self.client_socket.recv(4).decode("utf-8")
        except BlockingIOError:
            return None

    def close(self) -> None:
        """Close the peer connection and the listening socket."""
        self.client_socket.close()
        self.socket.close()


class Sender:
    """Go-Back-N sender.

    Sequence numbers start at 1. ``base_num`` is the oldest unacknowledged
    sequence number and ``next_seq_num`` is the next one to assign. At most
    ``window_size`` packets are outstanding at once.

    A single timer runs while packets are outstanding and is restarted on
    every ACK that advances the window. Callers poll ``is_timeout()`` and call
    ``gbn()`` to resend every outstanding packet.

    Args:
        window_size: Maximum number of unacknowledged packets in flight.
        max_seq_num: Largest sequence number that may be sent.
        timeout_ms: Retransmission timeout in milliseconds.
        networkLayer: Channel used for packets that are not dropped.
        loss_rate: Probability of deliberately dropping each transmission, to
            simulate an unreliable channel.
    """

    def __init__(
        self,
        window_size: int,
        max_seq_num: int,
        timeout_ms: int,
        networkLayer: NetworkLayer,
        loss_rate: float = 0.25,
    ) -> None:
        self.window_size = window_size
        self.max_seq_num = max_seq_num
        # Indexed directly by sequence number; slot 0 is unused.
        self.packet_list: list[Optional[Packet]] = [None] * (self.max_seq_num + 1)
        self.base_num = 1
        self.next_seq_num = 1
        self.timeout_ms = timeout_ms
        self.networkLayer = networkLayer
        self.loss_rate = loss_rate
        # time.time() at which the timer was (re)started, or None when stopped.
        self.timer: Optional[float] = None

    def rdt_send(self, data: str) -> bool:
        """Try to send *data* as the next packet.

        Returns:
            True if the packet was assigned a sequence number and transmitted.
            It may still have been dropped by the simulated loss.
            False if the sequence space is exhausted or the window is full.
            In that case the caller should retry later.
        """
        # Use this sender's own limit, never a module-level variable.
        if self.next_seq_num > self.max_seq_num:
            return False
        if self.next_seq_num >= self.base_num + self.window_size:
            return False
        packet = Packet(data, self.next_seq_num)
        self.packet_list[self.next_seq_num] = packet
        self.udt_send(packet.data, packet.seq_num)
        # Start the timer only if this is the sole outstanding packet.
        if self.base_num == self.next_seq_num:
            self.timer = time.time()
        self.next_seq_num += 1
        return True

    def rdt_rcv(self, ack_index: int) -> None:
        """Process a cumulative ACK for every packet up to *ack_index*.

        Two kinds of ACK are discarded:
        - stale ACKs (below ``base_num``);
        - ACKs for packets that were never sent (at or beyond ``next_seq_num``).
        """
        print(f"收到ACK={ack_index},", end="")
        if ack_index < self.base_num:
            print(f"(ACK={ack_index}) < (base={self.base_num}),ACK失效丢弃。")
            return
        if ack_index >= self.next_seq_num:
            # Accepting this would move base past next_seq_num and index
            # packet_list outside the range of packets actually sent.
            print(f"(ACK={ack_index}) >= (next_seq_num={self.next_seq_num}),ACK无效丢弃。")
            return
        self.base_num = ack_index + 1
        self.timer = time.time()
        if self.base_num == self.next_seq_num:
            # Window is empty: nothing left to time.
            print("将base_num设置为下一个序列编号。")
            self.timer = None
        else:
            print(f"将base_num设置为{self.packet_list[self.base_num].seq_num}。")

    def udt_send(self, data: str, index: int) -> None:
        """Frame *data* with *index* and hand it to the network layer.

        The frame is skipped, i.e. lost, with probability ``loss_rate``.
        Frame format: *index* zero-padded to at least 3 digits, a space, then
        *data*.
        """
        index_data = "{:0>3d} ".format(index) + data
        print(f"发送data=\"{index_data}\"", end="")
        if random.random() > self.loss_rate:
            self.networkLayer.udt_send(index_data)
        else:
            print(",此包丢失。", end="")
        print()

    def is_timeout(self) -> bool:
        """Return True if the timer is running and ``timeout_ms`` has elapsed."""
        if self.timer is None:
            return False
        return time.time() - self.timer >= 0.001 * self.timeout_ms

    def gbn(self) -> None:
        """Restart the timer and resend every outstanding packet."""
        self.timer = time.time()
        for seq_index in range(self.base_num, self.next_seq_num):
            packet = self.packet_list[seq_index]
            self.udt_send(packet.data, packet.seq_num)

    def show_gbn(self) -> list[int]:
        """Return the sequence numbers of the outstanding packets, oldest first."""
        return [
            self.packet_list[seq_index].seq_num
            for seq_index in range(self.base_num, self.next_seq_num)
        ]

    def get_ack_num(self, ack_str: str) -> Optional[int]:
        """Parse an ACK string into its number.

        Returns None if *ack_str* is not an integer; surrounding whitespace is
        allowed.
        """
        try:
            return int(ack_str)
        except ValueError:
            return None


def _poll_ack_and_timer(sender: Sender, network_layer: NetworkLayer) -> None:
    """Process at most one pending ACK, then retransmit the window on timeout."""
    ack_str = network_layer.udt_rcv()
    # Skip both "nothing available" (None) and a closed connection ("").
    if ack_str:
        ack_num = sender.get_ack_num(ack_str)
        if ack_num is not None:
            sender.rdt_rcv(ack_num)
    if sender.is_timeout():
        print(f"超时。重传{sender.show_gbn()}")
        sender.gbn()


def main() -> None:
    """Send ``max_seq_num`` messages with Go-Back-N, then wait for all ACKs."""
    max_seq_num = 20
    network_layer = NetworkLayer(host="0.0.0.0", port=23666)
    # Close the sockets even if transmission is interrupted (e.g. Ctrl+C).
    try:
        application_layer = ApplicationLayer(max_seq_num)
        sender = Sender(
            window_size=4,
            max_seq_num=max_seq_num,
            timeout_ms=2000,
            networkLayer=network_layer,
        )
        input("按回车键开始传输:")
        pkg_list = application_layer.data_to_send
        index = 1
        # Phase 1: hand every message to the sender, retrying while the
        # window is full.
        while index <= max_seq_num:
            time.sleep(1)
            if sender.rdt_send(pkg_list[index - 1]):
                index += 1
            _poll_ack_and_timer(sender, network_layer)
        # Phase 2: keep servicing ACKs and timeouts until the window drains.
        while sender.base_num < sender.next_seq_num:
            time.sleep(1)
            _poll_ack_and_timer(sender, network_layer)
        print("序列传输完成。")
    finally:
        network_layer.close()


if __name__ == "__main__":
    main()