# Python · 167 lines · 5.0 KiB
import time
|
||
import socket
|
||
import random
|
||
|
||
class Packet:
    """A payload tagged with the sequence number it is sent under."""

    def __init__(self, data: str, seq_num: int) -> None:
        # data: the payload as given by the caller (Sender prefixes the
        # sequence number only when framing it for transmission).
        # seq_num: the sequence number assigned by the sender.
        self.data, self.seq_num = data, seq_num


class ApplicationLayer:
    """Generates the list of messages the sender should transmit.

    Args:
        data_len: How many messages to generate (default 5000).
    """

    def __init__(self, data_len: int = 5000) -> None:
        self.data_len = data_len
        # "data0000", "data0001", ...: the index is zero-padded to at least
        # four digits (wider indices are written in full).
        self.data_to_send = [f"data{i:04d}" for i in range(data_len)]


class NetworkLayer:
    """Stand-in for the unreliable lower layer, backed by one TCP connection.

    The constructor listens on ``(host, port)``, blocks until a single peer
    connects, then switches that connection to non-blocking mode so that
    ``udt_rcv`` can be polled without stalling the caller.
    """

    def __init__(self, host: str, port: int) -> None:
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            # Allow an immediate restart while the previous run's socket is
            # still lingering, instead of failing with "Address already in use".
            self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            self.socket.bind((host, port))
            self.socket.listen(1)
            print("等待下层的不可靠传输连接。")
            self.client_socket, _ = self.socket.accept()
        except BaseException:
            # Don't leak the listening socket if bind/accept fails or the
            # wait is interrupted (e.g. Ctrl+C); cleanup only, then re-raise.
            self.socket.close()
            raise
        print("下层的不可靠传输连接成功。")
        self.client_socket.setblocking(False)

    def udt_send(self, data: str) -> None:
        """Transmit ``data`` (UTF-8 encoded) over the connection.

        ``sendall`` is used because a bare ``send`` may write only part of the
        buffer. On this non-blocking socket ``sendall`` raises
        ``BlockingIOError`` if the kernel buffer is full, rather than
        silently dropping the tail of the frame.
        """
        self.client_socket.sendall(data.encode())

    def udt_rcv(self) -> "str | None":
        """Poll the connection for up to 4 bytes of data.

        Returns:
            The received bytes decoded as UTF-8; ``None`` if nothing is
            available yet; ``""`` if the peer has closed the connection
            (``recv`` returned ``b""``).

        NOTE(review): TCP is a byte stream, so a 4-byte read may hold a
        partial or concatenated ACK. This assumes the receiver sends
        fixed-width 4-byte ACKs — confirm against the receiver.
        """
        try:
            return self.client_socket.recv(4).decode("utf-8")
        except BlockingIOError:
            # Non-blocking socket with nothing queued.
            return None

    def close(self) -> None:
        """Close the accepted connection and the listening socket."""
        self.client_socket.close()
        self.socket.close()


class Sender:
    """Go-Back-N (GBN) sender.

    Packets are numbered ``1..max_seq_num``; numbers are never wrapped or
    reused. At most ``window_size`` packets may be sent but unacknowledged.
    One timer covers the oldest unacknowledged packet. When it expires,
    ``gbn()`` resends every packet in ``[base_num, next_seq_num)``.

    Args:
        window_size: Maximum number of sent-but-unacknowledged packets.
        max_seq_num: Highest sequence number that may be used.
        timeout_ms: Retransmission timeout in milliseconds.
        networkLayer: Object whose ``udt_send(str)`` transmits one frame.
        loss_rate: Probability in ``[0, 1]`` that a frame is deliberately not
            handed to ``networkLayer``, simulating a lossy channel
            (default 0.25; 0 disables loss).
    """

    def __init__(
        self,
        window_size: int,
        max_seq_num: int,
        timeout_ms: int,
        networkLayer: "NetworkLayer",
        loss_rate: float = 0.25,
    ) -> None:
        self.window_size = window_size
        self.max_seq_num = max_seq_num
        # Slot i holds the Packet sent with sequence number i; slot 0 is
        # unused and unsent slots stay None.
        self.packet_list: list[Packet] = [None] * (self.max_seq_num + 1)
        self.base_num = 1  # oldest unacknowledged sequence number
        self.next_seq_num = 1  # sequence number the next new packet gets
        self.timeout_ms = timeout_ms
        self.networkLayer = networkLayer
        self.loss_rate = loss_rate
        # time.time() when the timer was last (re)started; None = stopped.
        self.timer = None

    def rdt_send(self, data: str) -> bool:
        """Send ``data`` as the next packet if sequence space and window allow.

        Returns:
            True if the packet was created and passed to ``udt_send`` (it may
            still be dropped by the simulated loss). False if the sequence
            space is exhausted or the window is full; the caller should retry
            the same data later.
        """
        # Check against this sender's own limit (not any module-level name),
        # so the class works when imported as well as when run as a script.
        if self.next_seq_num > self.max_seq_num:
            return False
        if self.next_seq_num >= self.base_num + self.window_size:
            return False

        packet = Packet(data, self.next_seq_num)
        self.packet_list[self.next_seq_num] = packet
        self.udt_send(packet.data, packet.seq_num)
        # The timer tracks the oldest unacknowledged packet: start it when
        # this packet is that oldest one.
        if self.base_num == self.next_seq_num:
            self.timer = time.time()
        self.next_seq_num += 1
        return True

    def rdt_rcv(self, ack_index: int) -> None:
        """Handle a cumulative ACK covering every packet up to ``ack_index``.

        Stale ACKs (below ``base_num``) and ACKs for sequence numbers that
        have not been sent are discarded. Otherwise the window slides to
        ``ack_index + 1``. The timer is then restarted, or stopped when
        nothing is outstanding.
        """
        print(f"收到ACK={ack_index},", end="")
        if ack_index < self.base_num:
            print(f"(ACK={ack_index}) < (base={self.base_num}),ACK失效丢弃。")
            return
        if ack_index >= self.next_seq_num:
            # Accepting this would move base_num past next_seq_num and index
            # an unsent (None) or out-of-range slot below.
            print(f"(ACK={ack_index}) >= (next_seq_num={self.next_seq_num}),ACK无效丢弃。")
            return

        self.base_num = ack_index + 1
        if self.base_num == self.next_seq_num:
            print("将base_num设置为下一个序列编号。")
            self.timer = None
        else:
            self.timer = time.time()
            print(f"将base_num设置为{self.packet_list[self.base_num].seq_num}。")

    def udt_send(self, data: str, index: int) -> None:
        """Frame ``data`` as ``"NNN <data>"`` and hand it to the network layer.

        With probability ``loss_rate`` the frame is not transmitted, to
        simulate loss.

        NOTE(review): ``index`` is zero-padded to 3 digits, so sequence
        numbers above 999 produce a wider prefix. Confirm the receiver can
        parse that.
        """
        index_data = '{:0>3d} '.format(index) + data
        print(f"发送data=\"{index_data}\"", end="")
        if random.random() >= self.loss_rate:
            self.networkLayer.udt_send(index_data)
        else:
            print(",此包丢失。", end="")
        print()

    def is_timeout(self) -> bool:
        """Return True once ``timeout_ms`` has elapsed since the timer started.

        Always False while the timer is stopped.
        """
        if self.timer is None:
            return False
        return time.time() - self.timer >= 0.001 * self.timeout_ms

    def gbn(self) -> None:
        """Go back N: restart the timer and resend every outstanding packet."""
        self.timer = time.time()
        for seq_index in range(self.base_num, self.next_seq_num):
            packet = self.packet_list[seq_index]
            self.udt_send(packet.data, packet.seq_num)

    def show_gbn(self) -> list[int]:
        """Return the sequence numbers ``gbn()`` would resend, in order."""
        return [
            self.packet_list[seq_index].seq_num
            for seq_index in range(self.base_num, self.next_seq_num)
        ]

    def get_ack_num(self, ack_str: str) -> int:
        """Parse an ACK string into its sequence number.

        Raises:
            ValueError: if ``ack_str`` is not an integer literal
                (surrounding whitespace is allowed by ``int``).
        """
        return int(ack_str)


if __name__ == "__main__":
    max_seq_num = 20

    networkLayer = NetworkLayer(host="0.0.0.0", port=23666)
    try:
        applicationLayer = ApplicationLayer(max_seq_num)
        sender = Sender(
            window_size=4,
            max_seq_num=max_seq_num,
            timeout_ms=2000,
            networkLayer=networkLayer,
        )
        input("按回车键开始传输:")

        pkg_list = applicationLayer.data_to_send
        index = 1  # 1-based number of the next application message to send

        # Run until every message has been handed to the sender AND every
        # sent packet has been acknowledged. Once all messages are handed
        # over, only ACK polling and retransmission remain.
        while index <= max_seq_num or sender.base_num < sender.next_seq_num:
            time.sleep(1)

            if index <= max_seq_num and sender.rdt_send(pkg_list[index - 1]):
                index += 1

            ack_str = networkLayer.udt_rcv()
            if ack_str == "":
                # recv() returned b"": the peer closed the connection, so no
                # more ACKs can arrive. Stop instead of crashing in int("")
                # or polling forever.
                print("下层连接已关闭，传输中止。")
                break
            if ack_str is not None:
                sender.rdt_rcv(sender.get_ack_num(ack_str))

            if sender.is_timeout():
                print(f"超时。重传{sender.show_gbn()}")
                sender.gbn()
        else:
            # Reached only when the loop condition became false (no break):
            # everything was sent and acknowledged.
            print("序列传输完成。")
    finally:
        # Release the sockets even on errors or Ctrl+C.
        networkLayer.close()