import time
import socket
import random

class Packet:
    """A payload string tagged with its sequence number."""

    def __init__(self, data:str, seq_num:int) -> None:
        # Plain attributes: the sender reads both when (re)transmitting.
        self.data, self.seq_num = data, seq_num


class ApplicationLayer:
    """Source of the messages to transmit: "data0000", "data0001", ..."""

    def __init__(self, data_len:int=5000) -> None:
        self.data_len = data_len
        # Index is zero-padded to at least four digits; larger indices widen.
        self.data_to_send = [f"data{i:0>4d}" for i in range(data_len)]


class NetworkLayer:
    """Stand-in for an unreliable channel, carried over one accepted TCP connection.

    The constructor blocks until a single peer connects. The accepted socket
    is then switched to non-blocking mode so that udt_rcv() can be polled.
    """

    def __init__(self, host:str, port:int) -> None:
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            self.socket.bind((host, port))
            self.socket.listen(1)
            print("等待下层的不可靠传输连接。")
            self.client_socket, _ = self.socket.accept()
        except BaseException:
            # Don't leak the listening socket if bind/accept fails or the
            # wait for a peer is interrupted (e.g. Ctrl+C).
            self.socket.close()
            raise
        print("下层的不可靠传输连接成功。")
        self.client_socket.setblocking(False)

    def udt_send(self, data:str):
        """Send *data* to the peer, UTF-8 encoded.

        sendall() is used because send() may transmit only part of the
        buffer, which would corrupt the frame boundaries seen by the peer.
        """
        self.client_socket.sendall(data.encode("utf-8"))

    def udt_rcv(self):
        """Poll for up to 4 bytes from the peer without blocking.

        Returns:
            The decoded text, None if nothing is pending, or "" once the
            peer has closed the connection.
        """
        # NOTE(review): assumes ACKs are 4-byte ASCII frames that arrive
        # whole; a TCP stream may split or merge them — confirm with the
        # receiver's format.
        try:
            return self.client_socket.recv(4).decode("utf-8")
        except BlockingIOError:
            return None

    def close(self):
        """Close the connected socket, then the listening socket."""
        try:
            self.client_socket.close()
        finally:
            self.socket.close()


class Sender:
    """Go-Back-N (GBN) sender with simulated packet loss.

    Sequence numbers start at 1 and go up to max_seq_num, with no
    wrap-around. base_num is the oldest unacknowledged sequence number and
    next_seq_num is the next one to be assigned. A single timer tracks the
    oldest in-flight packet.
    """

    def __init__(
            self,
            window_size:int,
            max_seq_num:int,
            timeout_ms:int,
            networkLayer:"NetworkLayer",
            loss_prob:float=0.25,
        ) -> None:
        """Create a sender.

        Args:
            window_size: maximum number of unacknowledged packets in flight.
            max_seq_num: highest sequence number that may be sent.
            timeout_ms: retransmission timeout in milliseconds.
            networkLayer: object whose udt_send(str) carries the frames.
            loss_prob: probability that a frame is dropped instead of being
                handed to the network layer (simulated channel loss).
        """
        self.window_size = window_size
        self.max_seq_num = max_seq_num
        # Indexed directly by sequence number; slot 0 is never used.
        self.packet_list:list[Packet | None] = [None] * (self.max_seq_num + 1)
        self.base_num = 1
        self.next_seq_num = 1
        self.timeout_ms = timeout_ms
        self.networkLayer = networkLayer
        self.loss_prob = loss_prob
        self.timer = None  # time.time() of the last (re)start; None = stopped

    def rdt_send(self, data:str) -> bool:
        """Try to send *data* as the next packet.

        Returns:
            False (and sends nothing) if every sequence number up to
            max_seq_num is used or the window is full; otherwise True.
        """
        # Use the instance's limit, not a module-level global of the same name.
        if self.next_seq_num > self.max_seq_num:
            return False
        if self.next_seq_num >= self.base_num + self.window_size:
            return False

        packet = Packet(data, self.next_seq_num)
        self.packet_list[self.next_seq_num] = packet
        self.udt_send(packet.data, packet.seq_num)
        # Start the timer when this packet is the oldest one in flight.
        if self.base_num == self.next_seq_num:
            self.timer = time.time()
        self.next_seq_num += 1
        return True

    def rdt_rcv(self, ack_index:"int | None"):
        """Process a cumulative ACK for sequence number *ack_index*.

        Stale ACKs (below base_num) and ACKs for packets never sent (at or
        beyond next_seq_num) are discarded. None, which get_ack_num()
        returns for unparsable input, is ignored.
        """
        if ack_index is None:
            return
        print(f"收到ACK={ack_index},", end="")
        if ack_index < self.base_num:
            print(f"(ACK={ack_index}) < (base={self.base_num}),ACK失效丢弃。")
            return
        if ack_index >= self.next_seq_num:
            # Accepting this would move base_num past next_seq_num, and the
            # packet_list lookup below would hit an unsent or missing slot.
            print(f"(ACK={ack_index}) >= (nextseqnum={self.next_seq_num}),ACK无效丢弃。")
            return
        self.base_num = ack_index + 1
        if self.base_num == self.next_seq_num:
            print("将base_num设置为下一个序列编号。")
            # Nothing is left in flight, so stop the timer.
            self.timer = None
        else:
            # Restart the timer for the new oldest in-flight packet.
            self.timer = time.time()
            print(f"将base_num设置为{self.packet_list[self.base_num].seq_num}。")

    def udt_send(self, data:str, index:int):
        """Frame *data* as a zero-padded 3-digit sequence number plus a
        space, then pass it to the network layer. With probability
        loss_prob the frame is dropped instead."""
        index_data = '{:0>3d} '.format(index) + data
        print(f"发送data=\"{index_data}\"", end="")
        if random.random() > self.loss_prob:
            self.networkLayer.udt_send(index_data)
        else:
            print(",此包丢失。", end="")
        print()

    def is_timeout(self) -> bool:
        """Return True if the timer is running and timeout_ms has elapsed."""
        if self.timer is None:
            return False
        return time.time() - self.timer >= 0.001 * self.timeout_ms

    def gbn(self):
        """Go-Back-N retransmission.

        Restarts the timer and resends every packet in
        [base_num, next_seq_num).
        """
        self.timer = time.time()
        for seq_index in range(self.base_num, self.next_seq_num):
            packet = self.packet_list[seq_index]
            self.udt_send(packet.data, packet.seq_num)

    def show_gbn(self) -> list[int]:
        """Return the sequence numbers that gbn() would retransmit."""
        return [
            self.packet_list[seq_index].seq_num
            for seq_index in range(self.base_num, self.next_seq_num)
        ]

    def get_ack_num(self, ack_str:str) -> "int | None":
        """Parse an ACK string into a sequence number.

        Surrounding whitespace is accepted, since int() strips it. Returns
        None for empty, non-numeric, or None input instead of raising.
        """
        try:
            return int(ack_str)
        except (TypeError, ValueError):
            return None
        

def _poll_ack_and_timeout(net_layer:"NetworkLayer", gbn_sender:"Sender") -> None:
    """Handle at most one pending ACK, then retransmit the window on timeout."""
    ack_str = net_layer.udt_rcv()
    # None means nothing is pending; "" means the peer closed the connection.
    # Neither carries an ACK (int("") would raise ValueError).
    if ack_str:
        ack_num = gbn_sender.get_ack_num(ack_str)
        if ack_num is not None:
            gbn_sender.rdt_rcv(ack_num)

    if gbn_sender.is_timeout():
        print(f"超时。重传{gbn_sender.show_gbn()}")
        gbn_sender.gbn()


if __name__ == "__main__":
    max_seq_num = 20
    networkLayer = NetworkLayer(host="0.0.0.0", port=23666)
    try:
        applicationLayer = ApplicationLayer(max_seq_num)
        sender = Sender(
            window_size=4,
            max_seq_num=max_seq_num,
            timeout_ms=2000,
            networkLayer=networkLayer,
        )
        input("按回车键开始传输:")

        pkg_list = applicationLayer.data_to_send
        index = 1
        # Phase 1: offer each message in turn, pacing one attempt per second.
        # rdt_send() refuses while the window is full, so the same message is
        # retried until an ACK slides the window.
        while index <= max_seq_num:
            time.sleep(1)
            if sender.rdt_send(pkg_list[index - 1]):
                index += 1
            _poll_ack_and_timeout(networkLayer, sender)

        # Phase 2: everything has been handed over; keep servicing ACKs and
        # timeouts until every packet is acknowledged.
        while sender.base_num < sender.next_seq_num:
            time.sleep(1)
            _poll_ack_and_timeout(networkLayer, sender)

        print("序列传输完成。")
    finally:
        # Release both sockets even on error or Ctrl+C.
        networkLayer.close()