What I Learned Building a Custom Async WebSocket Library in Rust

Introduction

Software developers often ask themselves if they should engage in side projects, usually with the goal of enhancing their professional profile or personal branding.

The real question, though, is: What project should I create? Many developers fall into the trap of trying to build something entirely new and groundbreaking—an approach that, more often than not, isn't the best idea.

In my experience, it's perfectly valid to choose a simple topic or something already familiar in your daily work, and focus on solving a specific pain point for a particular group of people.

With this in mind, I reached a point in my career where I wanted to understand how things work at a deeper level. I wanted to explore how memory allocators function, how programming languages are built from scratch, and more. Around the same time, I was working with some asynchronous libraries in Rust (based on Tokio), and throughout my professional life, I often dealt with the WebSocket protocol.

That’s when the idea struck: Why not dive deeper into how asynchronous Rust works by building my own asynchronous WebSocket library? You might be thinking:

"Are you really going to reinvent the wheel?"

Well, sort of. There are already well-established WebSocket libraries in Rust, such as:

  • tungstenite

  • tokio-tungstenite

But I always felt that most of these libraries were accompanied only by examples and minimal documentation, leaving a gap for more comprehensive guides. So, I decided to write my own WebSocket async library from scratch for two main reasons:

  1. To gain a deep understanding of how async Rust works.

  2. To create more useful examples that people could easily apply in their own projects.

Now that you know the reasons behind this decision, let me take you through my 3-month journey of building the library and the challenges I encountered along the way.

By the way, if you're looking for project inspiration, check out this awesome resource: Awesome Rust.

Also, here is the link to the library; consider giving it a star!

socket-flow

Some Essential Concepts

Before diving into my learnings, let’s briefly cover some key concepts about asynchronous Rust.

Rust’s async model allows for writing efficient, non-blocking code, which is ideal for I/O-bound tasks such as networking, file handling, and WebSockets. To fully grasp Rust's async capabilities, you need to understand some core ideas, including how Tokio—the most popular async runtime in Rust—handles asynchronous tasks and concurrency. (Remember, concurrency is different from parallelism!)

1. The Foundation: async/await

In Rust, the async and await keywords allow you to write asynchronous code that looks similar to synchronous code, but with non-blocking behavior.

  • async fn: Declares an asynchronous function, meaning it returns a value that implements the Future trait.

      async fn example() -> u32 {
          42
      }
    
  • await: Suspends the execution of an async function until the future is ready (i.e., the result is available). This does not block the current thread but instead pauses execution, allowing other tasks to run.

      let result = example().await; // Waits for the result asynchronously.
    

In essence, asynchronous code allows multiple tasks to "pause" while waiting for some I/O operation (like reading data from a WebSocket), enabling other tasks to run in the meantime.

2. Futures in Rust

Under the hood, all async functions return a Future. A future is essentially a value representing a computation that may complete in the future.

A basic future looks like this:

use std::future::Future;

fn create_future() -> impl Future<Output = u32> {
    async { 42 }
}

This future is lazy: its body only runs when it's polled, which happens when you await it or hand it to a runtime. Tokio manages these futures by polling them on its async runtime.
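To make that laziness concrete, here is a minimal sketch that polls a future by hand using only the standard library. The NoopWaker type and the poll_once_demo function are names invented for this illustration; in real code a runtime such as Tokio does the polling for you.

```rust
use std::future::Future;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

// A waker that does nothing; enough for a future that never returns Pending.
struct NoopWaker;

impl Wake for NoopWaker {
    fn wake(self: Arc<Self>) {}
}

// Hypothetical helper: returns (ran_before_poll, ran_after_poll, output).
fn poll_once_demo() -> (bool, bool, u32) {
    let ran = Arc::new(AtomicBool::new(false));
    let flag = Arc::clone(&ran);

    // Creating the future executes none of its body.
    let mut fut = Box::pin(async move {
        flag.store(true, Ordering::SeqCst);
        42
    });
    let ran_before = ran.load(Ordering::SeqCst);

    let waker = Waker::from(Arc::new(NoopWaker));
    let mut cx = Context::from_waker(&waker);

    // The body runs only now, when the future is polled.
    let output = match fut.as_mut().poll(&mut cx) {
        Poll::Ready(value) => value,
        Poll::Pending => unreachable!("this future has no await points"),
    };
    (ran_before, ran.load(Ordering::SeqCst), output)
}

fn main() {
    let (before, after, output) = poll_once_demo();
    println!("ran before poll: {before}, ran after poll: {after}, output: {output}");
}
```

The flag stays false until the first poll, which is exactly why a future that is created but never awaited or spawned does nothing.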

3. Why Tokio?

Rust's async ecosystem relies on async runtimes to execute asynchronous tasks, and Tokio is one of the most widely used runtimes. It provides several critical components:

  • Task scheduler: Runs futures concurrently, making sure that each task progresses as soon as it’s ready.

  • Async I/O support: Handles non-blocking I/O operations, like reading/writing from/to sockets or files.

  • Utilities: Provides timers, channels, and other helpful async tools.

I know this is only a superficial overview, but a full treatment of async Rust would take an entire article of its own. Maybe in the next one.

The Implementation

WebSockets is a communication protocol that provides full-duplex, bi-directional communication over a single, long-lived TCP connection. Unlike traditional HTTP, which is unidirectional and relies on a request/response model where the client must request data from the server, WebSockets enable real-time communication where both the client and server can send and receive messages independently, without the need for additional HTTP requests.

The full specification of the protocol and its attributes can be found in RFC 6455. It’s a widely used protocol in various industries that allows a persistent connection between a server and client, facilitating the continuous exchange of data between them.

Everything in WebSockets builds on top of TCP connections. And when we think of TCP connections, what comes to mind—aside from TLS? Sockets! A socket is an endpoint in a network communication link, used to send and receive data. Sockets operate at the transport layer, typically over protocols like TCP (Transmission Control Protocol) and UDP (User Datagram Protocol).

Sockets transmit data as streams of bytes between devices across a network. A stream of bytes is a sequence of data elements that are made available sequentially over time. Streams allow for a continuous flow of data, which is commonly used in network communication, file I/O, and other data-processing scenarios.

Once a client establishes a TCP connection with a server, a socket is created. You can think of this as a road where traffic can flow in both directions. This TCP connection (or socket) forms the foundation for WebSockets, as it provides the "road" on which bytes can travel. Like HTTP, WebSockets have standards that define how these bytes must be formatted in order to adhere to the protocol, as outlined in the RFC.

Here’s a quick overview of how the WebSocket protocol works:

  1. A TCP connection is established (with optional TLS/SSL for encryption).

  2. The client sends a handshake request to the server in the format of an HTTP request.

  3. The server reads the handshake request and, if it follows the correct format, upgrades the connection and sends an HTTP response back to the client.

  4. After the handshake, both client and server are free to communicate with each other. Each message they send is composed of one or more frames, which follow a specific structure.

  5. If a frame doesn't adhere to the standard structure, the connection may be dropped by either the server or the client.

In essence, it’s all about streams of bytes following a set of predetermined rules. One of the challenges I faced during development wasn’t necessarily adhering to the WebSocket standards outlined in the RFC, but rather writing code that is reusable, performant, and easy to read.

Moreover, working with bytes can be much more challenging than working with strings. Many times, I found myself wrestling with bitwise operations (like AND/OR/Shift) to ensure the correct values were aligned with how a WebSocket frame should be structured. Reconstructing a frame bit by bit required a deep understanding of these byte-level operations.

On top of that, Rust adds its own complexities, particularly with its ownership rules. Since I was working in an asynchronous environment, dealing with concurrent processes often required sharing references between them. This made things even more challenging, especially when trying to manage memory safely while still ensuring optimal performance.

The Frame

Just to give you a better understanding of the framing process in WebSockets, let's go over some key concepts related to frames.

Imagine you want to send the following JSON to the client:

{
  "_id": "670976a2e5c91a37916d9850",
  "index": 0,
  "guid": "d9c82ede-77f5-40d3-a2af-efcc890813bf",
  "isActive": true,
  "balance": "$3,510.82",
  "picture": "http://placehold.it/32x32"
}

A frame is the structure used to transmit data in the WebSocket protocol, whether you're the server or the client. Whenever you send data, your payload needs to adhere to a specific frame format. The message you want to send (such as the JSON above) can be transmitted in either a single frame or split across multiple frames. Let’s break down the frame structure:

[Figure: WebSocket frame structure diagram. Source: "WebSockets Demystified, Part 1: Understanding the Protocol" by Damiano Magrini, Level Up Coding]

As you may know, everything transmitted over network connections—or even when reading or writing local files—is sent as bits. A frame is simply a group of bits that follows a set of rules and contains the payload you want to send. Essentially, you're reading byte by byte from the stream and checking if the group of received bits/bytes conforms to the frame structure.

This diagram might look intimidating at first, but we can dissect each element of a frame:

  • FIN: Indicates whether this frame is the final fragment of a message. A message that fits in a single frame has FIN set on that frame; a fragmented message sets it only on the last frame.

  • RSV 1, 2, 3: These are reserved bits for more advanced logic, such as compression.

  • Opcode: This designates the type of the frame's payload and can represent options such as Continuation, Text, Binary, Close, Ping, or Pong.

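  • Mask: Indicates whether the payload is masked. Masking performs a bitwise XOR between each byte of the payload and a byte of the masking key. It is not encryption: RFC 6455 requires it on client-to-server frames so that intermediaries such as proxies can't be tricked by attacker-chosen bytes on the wire.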

  • Payload Length: Specifies the length of the payload. The field itself is 7 bits wide; the special values 126 and 127 mean the real length follows in the next 2 or 8 bytes, respectively.

  • Masking Key: Represents the key used to mask the payload.

  • Payload: Finally, this is the actual data being transmitted.
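As a rough illustration of how these fields fit together, here is a small sketch that serializes one unmasked (server-to-client) frame. The encode_frame function is a name I made up for this example, not the library's API, and it only covers the fields listed above.

```rust
// Hypothetical helper: serialize one unmasked (server-to-client) frame.
fn encode_frame(fin: bool, opcode: u8, payload: &[u8]) -> Vec<u8> {
    let mut frame = Vec::with_capacity(payload.len() + 10);

    // Byte 0: FIN in the top bit, RSV1-3 left at zero, opcode in the low 4 bits.
    let fin_bit: u8 = if fin { 0b1000_0000 } else { 0 };
    frame.push(fin_bit | (opcode & 0b0000_1111));

    // Byte 1: MASK bit (0 here) plus the 7-bit length or a 126/127 marker.
    let len = payload.len();
    if len <= 125 {
        frame.push(len as u8);
    } else if len <= u16::MAX as usize {
        frame.push(126);
        frame.extend_from_slice(&(len as u16).to_be_bytes());
    } else {
        frame.push(127);
        frame.extend_from_slice(&(len as u64).to_be_bytes());
    }

    frame.extend_from_slice(payload);
    frame
}

fn main() {
    // A single-frame text message (opcode 0x1) containing "Hello".
    let frame = encode_frame(true, 0x1, b"Hello");
    println!("{:02x?}", frame);
}
```

For "Hello" this produces the bytes 81 05 48 65 6c 6c 6f, which matches the unmasked text example given in RFC 6455.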

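A single WebSocket frame can theoretically carry a payload of up to 2^63 - 1 bytes (the extended length field is 64 bits wide, but RFC 6455 requires its most significant bit to be 0). In practice, large messages are often split across multiple frames, which lets the sender start transmitting before the whole message has been buffered.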
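Fragmentation follows a simple rule: the first frame carries the real opcode, every later frame uses the Continuation opcode, and only the last frame has FIN set. The sketch below plans the fragments for a message; the fragment function and its max_chunk parameter are assumptions made for this illustration.

```rust
// Opcode values from RFC 6455.
const OPCODE_CONTINUATION: u8 = 0x0;
const OPCODE_TEXT: u8 = 0x1;

// Hypothetical helper: plan the (fin, opcode, chunk) triple for each fragment.
fn fragment(opcode: u8, payload: &[u8], max_chunk: usize) -> Vec<(bool, u8, &[u8])> {
    assert!(max_chunk > 0, "max_chunk must be at least 1");

    // An empty message still needs one (final) frame.
    if payload.is_empty() {
        return vec![(true, opcode, payload)];
    }

    let total = (payload.len() + max_chunk - 1) / max_chunk;
    payload
        .chunks(max_chunk)
        .enumerate()
        .map(|(i, chunk)| {
            let fin = i + 1 == total; // only the last fragment sets FIN
            let op = if i == 0 { opcode } else { OPCODE_CONTINUATION };
            (fin, op, chunk)
        })
        .collect()
}

fn main() {
    for (fin, opcode, chunk) in fragment(OPCODE_TEXT, b"Hello", 3) {
        println!("fin={fin} opcode={opcode:#x} chunk={:?}", chunk);
    }
}
```

Splitting "Hello" into chunks of 3 bytes gives a Text frame with "Hel" and FIN unset, followed by a Continuation frame with "lo" and FIN set.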

Here’s a function I wrote for reading frames from the stream after the connection has been established:

pub async fn read_frame(&mut self) -> Result<Frame, Error> {
    let mut header = [0u8; 2];

    self.buf_reader.read_exact(&mut header).await?;

    // The first bit in the first byte in the frame tells us whether the current frame is the final fragment of a message
    // here we are getting the native binary 0b10000000 and doing a bitwise AND operation
    let final_fragment = (header[0] & 0b10000000) != 0;
    // The opcode is the last 4 bits of the first byte in a websockets frame, here we are doing a bitwise AND operation & 0b00001111
    // to get the last 4 bits of the first byte
    let opcode = OpCode::from(header[0] & 0b00001111)?;

    // RSV is a short for "Reserved" fields, they are optional flags that aren't used by the
    // base websockets protocol, only if there is an extension of the protocol in use.
    // If these bits are received as non-zero in the absence of any defined extension, the connection
    // needs to fail, immediately
    let rsv1 = (header[0] & 0b01000000) != 0;
    let rsv2 = (header[0] & 0b00100000) != 0;
    let rsv3 = (header[0] & 0b00010000) != 0;

    if rsv1 || rsv2 || rsv3 {
        return Err(Error::RSVNotZero);
    }

    // As a rule in websockets protocol, if your opcode is a control opcode(ping,pong,close), your message can't be fragmented(split between multiple frames)
    if !final_fragment && opcode.is_control() {
        Err(Error::ControlFramesFragmented)?;
    }

    // According to the websocket protocol specification, the first bit of the second byte of each frame is the "Mask bit"
    // it tells us if the payload is masked or not
    let masked = (header[1] & 0b10000000) != 0;

    // In the second byte of a WebSocket frame, the first bit is used to represent the
    // Mask bit - which we discussed before - and the next 7 bits are used to represent the
    // payload length, or the size of the data being sent in the frame.
    let mut length = (header[1] & 0b01111111) as usize;

    // Control frames are only allowed to have a payload up to and including 125 octets
    if length > 125 && opcode.is_control() {
        Err(Error::ControlFramePayloadSize)?;
    }

    if length == 126 {
        let mut be_bytes = [0u8; 2];
        self.buf_reader.read_exact(&mut be_bytes).await?;
        length = u16::from_be_bytes(be_bytes) as usize;
    } else if length == 127 {
        let mut be_bytes = [0u8; 8];
        self.buf_reader.read_exact(&mut be_bytes).await?;
        length = u64::from_be_bytes(be_bytes) as usize;
    }

    if length > MAX_PAYLOAD_SIZE {
        Err(Error::PayloadSize)?;
    }

    // According to Websockets RFC, client should always send masked frames,
    // while frames sent from server to a client are not masked
    let mask = if masked {
        let mut mask = [0u8; 4];
        self.buf_reader.read_exact(&mut mask).await?;
        Some(mask)
    } else {
        None
    };

    let mut payload = vec![0u8; length];

    // Adding a timeout function from Tokio, to avoid malicious TCP connections, that passes through handshake
    // and starts to send invalid websockets frames to overload the socket
    // Since HTTP is an application protocol built on the top of TCP, a malicious TCP connection may send a string with the HTTP content in the
    // first connection, to simulate a handshake, and start sending huge payloads.
    let read_result = timeout(
        Duration::from_secs(5),
        self.buf_reader.read_exact(&mut payload),
    )
    .await;
    match read_result {
        Ok(Ok(_)) => {}        // Continue processing the payload
        Ok(Err(e)) => Err(e)?, // An error occurred while reading
        Err(e) => Err(e)?,     // Reading from the socket timed out
    }

    // Unmasking
    // According to the WebSocket protocol, all frames sent from the client to the server must be
    // masked by a four-byte value, which is often random. This "masking key" is part of the frame
    // along with the payload data and helps to prevent specific bytes from being discernible on the
    // network.
    // The mask is applied using a simple bitwise XOR operation. Each byte of the payload data
    // is XOR'd with the corresponding byte (modulo 4) of the 4-byte mask. The server then uses
    // the masking key to reverse the process, recovering the original data.
    if let Some(mask) = mask {
        for (i, byte) in payload.iter_mut().enumerate() {
            *byte ^= mask[i % 4];
        }
    }

    Ok(Frame {
        final_fragment,
        opcode,
        payload,
    })
}

If you take a closer look, you'll notice that we're performing several bitwise operations to extract each group of bits and ensure the data received from the stream adheres to the correct frame structure.
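If you want to watch those bitwise operations work on real bytes, here is a simplified, synchronous sketch that decodes a short frame from a byte slice. The parse_small_frame function is hypothetical and deliberately ignores extended lengths and RSV checks; the sample bytes are the masked "Hello" frame from RFC 6455.

```rust
// Hypothetical, synchronous mini-parser for short frames (payload <= 125 bytes).
// Returns (fin, opcode, unmasked payload), or None if the input is too short
// or uses an extended length.
fn parse_small_frame(bytes: &[u8]) -> Option<(bool, u8, Vec<u8>)> {
    let (&b0, rest) = bytes.split_first()?;
    let (&b1, rest) = rest.split_first()?;

    let fin = (b0 & 0b1000_0000) != 0; // top bit of byte 0
    let opcode = b0 & 0b0000_1111; // low 4 bits of byte 0
    let masked = (b1 & 0b1000_0000) != 0; // top bit of byte 1
    let length = (b1 & 0b0111_1111) as usize; // low 7 bits of byte 1
    if length > 125 {
        return None; // extended lengths are out of scope for this sketch
    }

    // The 4-byte masking key is present only when the MASK bit is set.
    let (mask, rest) = if masked {
        if rest.len() < 4 {
            return None;
        }
        let (key, rest) = rest.split_at(4);
        (Some([key[0], key[1], key[2], key[3]]), rest)
    } else {
        (None, rest)
    };

    let mut payload = rest.get(..length)?.to_vec();
    if let Some(key) = mask {
        for (i, byte) in payload.iter_mut().enumerate() {
            *byte ^= key[i % 4]; // XOR is its own inverse
        }
    }
    Some((fin, opcode, payload))
}

fn main() {
    // Masked single-frame text message from RFC 6455.
    let frame: [u8; 11] = [
        0x81, 0x85, 0x37, 0xfa, 0x21, 0x3d, 0x7f, 0x9f, 0x4d, 0x51, 0x58,
    ];
    if let Some((fin, opcode, payload)) = parse_small_frame(&frame) {
        println!("fin={fin} opcode={opcode} payload={}", String::from_utf8_lossy(&payload));
    }
}
```

The first byte 0x81 decodes to FIN set with the Text opcode, 0x85 decodes to MASK set with a length of 5, and XOR-ing the payload with the key 37 fa 21 3d recovers "Hello".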

This small function took me a significant amount of time to write. There was a lot of back and forth to ensure the function was reliable and met all the criteria outlined in the RFC.

The First Library Modules

Now that we’ve discussed frames and some basic information about WebSocket connections, let's move on. If you think the first versions of my code were perfectly organized and bug-free, you're completely mistaken. Also, from here on, let’s refer to WebSockets as "WS" for simplicity.

Since tools like Postman can connect to WS endpoints, I initially focused on the server side: the first functions I wrote simply spawned a TCP server that could accept connections, with Postman acting as the WebSocket client for my rudimentary server.

Thus, my first function listened on a TCP port, accepted connections, and handled each connection asynchronously by passing it to the connection socket, where all the processing began. Here's a glimpse of that early, somewhat clunky code:

#[tokio::main]
pub async fn main() -> io::Result<()> {
    let listener = TcpListener::bind("127.0.0.1:9000").await?;

    loop {
        let (mut socket, _) = listener.accept().await?;
        tokio::spawn(async move {
            // WebSocket handshake
            const RESPONSE: &'static str = "\
            HTTP/1.1 101 Switching Protocols\r\n\
            Upgrade: websocket\r\n\
            Connection: Upgrade\r\n\
            Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\
            \r\n";

            let mut buf = vec![0; 1024];
            socket.read(&mut buf).await.unwrap();
            let sec_websocket_key = parse_websocket_key(&buf);
            let accept_value = generate_websocket_accept_value(&sec_websocket_key.unwrap());
            let response = format!(
                "HTTP/1.1 101 Switching Protocols\r\n\
            Connection: Upgrade\r\n\
            Upgrade: websocket\r\n\
            Sec-WebSocket-Accept: {}\r\n\
            \r\n", accept_value);
            socket.write_all(response.as_bytes()).await.unwrap();


            socket.read(&mut buf).await.unwrap();
            // if request.contains("Upgrade: websocket") {
            //     socket.write_all(RESPONSE.as_bytes()).await.unwrap();
            // }
        });
    }
}

As you can see, I was binding a listener to port 9000, accepting each incoming connection, and moving its socket into a separately spawned task. Since this was an asynchronous environment, I used Tokio tasks to process each connection concurrently.

When testing with Postman, it was already sending the handshake request to my server. I read the request through my socket, checked if it was a proper WebSocket handshake, and responded with the handshake reply.
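Checking whether a request is a proper WebSocket handshake mostly comes down to inspecting a few headers. The sketch below shows one way to do that with the standard library only; header_value and websocket_key are hypothetical helpers, not the functions from the commit. Computing the Sec-WebSocket-Accept reply (a Base64-encoded SHA-1 hash of the key plus a fixed GUID) is left out because it needs external crates.

```rust
// Hypothetical helper: find a header value by case-insensitive name.
fn header_value<'a>(request: &'a str, name: &str) -> Option<&'a str> {
    request
        .lines()
        .skip(1) // skip the request line, e.g. "GET /chat HTTP/1.1"
        .take_while(|line| !line.is_empty()) // headers end at the blank line
        .filter_map(|line| line.split_once(':'))
        .find(|(key, _)| key.trim().eq_ignore_ascii_case(name))
        .map(|(_, value)| value.trim())
}

// Hypothetical helper: return the Sec-WebSocket-Key if this looks like a
// WebSocket upgrade request, otherwise None.
fn websocket_key(request: &str) -> Option<&str> {
    let upgrade = header_value(request, "Upgrade")?;
    let connection = header_value(request, "Connection")?;
    let version = header_value(request, "Sec-WebSocket-Version")?;

    let is_upgrade = upgrade.eq_ignore_ascii_case("websocket")
        && connection
            .split(',')
            .any(|token| token.trim().eq_ignore_ascii_case("upgrade"))
        && version == "13";

    if request.starts_with("GET ") && is_upgrade {
        header_value(request, "Sec-WebSocket-Key")
    } else {
        None
    }
}

fn main() {
    let request = "GET /chat HTTP/1.1\r\n\
                   Host: server.example.com\r\n\
                   Upgrade: websocket\r\n\
                   Connection: Upgrade\r\n\
                   Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
                   Sec-WebSocket-Version: 13\r\n\r\n";
    println!("{:?}", websocket_key(request));
}
```

A plain HTTP request without the upgrade headers yields None, which is the point where a server would reject the connection instead of replying with 101 Switching Protocols.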

If you'd like to see this code in full, you can check out this old commit, which marks the early days of the library: socket-flow - initial commit.

From that point onward, I started making the code more robust by implementing all the basic WebSocket RFC standards. Since the handshake process was complete, I moved on to create a function that constantly read frames and delivered them to the end user of the library.

Remember, the main goal of this library was to shield the end user from having to deal with complex stream configurations and all the internal protocol mechanics. To achieve this, I used Tokio channels to transmit incoming frames to the user. If the user wanted to send a message, they could also use the channel to transmit it.

The process looked like this:

  1. The socket from the connection performs the handshake with the client.

  2. Internally, the library splits the socket stream into a read half and a write half, creates the Tokio channels, and spawns another concurrent task to read any incoming frames and transmit them through the channel.

  3. The user interacts with a WSConnection struct that contains the Tokio channels. This allows the user to receive incoming messages through the rx (receive) channel and send messages through the tx (transmit) channel.

With this setup, the library handled all the protocol’s background processes concurrently, while the user only had to deal with the resulting messages.
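To show the shape of this design without pulling in any crates, here is a sketch that uses std threads and std::sync::mpsc channels as a stand-in for Tokio tasks and Tokio channels. The Connection struct and the connect_to_fake_peer function are invented for this illustration, and the "peer" simply echoes messages in upper case instead of talking to a real socket.

```rust
use std::sync::mpsc::{channel, Receiver, Sender};
use std::thread;

// Hypothetical stand-in for the library's connection handle.
struct Connection {
    incoming: Receiver<String>, // messages read from the "socket"
    outgoing: Sender<String>,   // messages to be written to the "socket"
}

// Hypothetical helper: wire up a fake peer that echoes every message in upper case.
fn connect_to_fake_peer() -> Connection {
    let (to_peer, peer_rx) = channel::<String>();
    let (peer_tx, from_peer) = channel::<String>();

    // Background worker: plays the role of the spawned read/write task.
    thread::spawn(move || {
        for message in peer_rx {
            if peer_tx.send(message.to_uppercase()).is_err() {
                break; // the user dropped the connection handle
            }
        }
    });

    Connection {
        incoming: from_peer,
        outgoing: to_peer,
    }
}

fn main() {
    let conn = connect_to_fake_peer();
    conn.outgoing.send("hello".to_string()).unwrap();
    println!("{}", conn.incoming.recv().unwrap());
}
```

The user-facing code only touches the two channel ends, while everything protocol-related would live inside the background worker.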

It took countless hours to fix many bugs along the way, and thanks to Rust’s infamous borrow checker (which everyone loves to hate), I faced numerous ownership and value movement issues due to the asynchronous environment and the spawning of many tasks.

You can check out this stage of the code here: socket-flow - handshake stage.

The Client

After finishing the foundational code for the server, I moved on to creating the WebSocket client connection function. This largely reused many of the functions I had already written, with the main differences being in the handshake process and how opcodes, such as Close, are handled.

Security

Another important concern was ensuring that the library didn't introduce vulnerabilities. This prompted me to rewrite certain parts of the logic to prevent potential exploits.

A simple example is in the handshake request. When establishing a TCP connection between two entities, there may be a period of inactivity with no data being sent or received. Now, imagine not setting a timeout for the handshake request. An attacker could easily create a script that opens a TCP connection but never completes the handshake, causing the server to hang while waiting indefinitely for data. Without a timeout mechanism, this would leave the system susceptible to DoS/DDoS attacks.
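The idea can be sketched with blocking std sockets, where set_read_timeout plays the role that tokio::time::timeout plays in async code. The read_handshake and silent_client_times_out functions are hypothetical, and the 100 ms limit is chosen only to keep the demo fast.

```rust
use std::io::{self, Read};
use std::net::{TcpListener, TcpStream};
use std::time::Duration;

// Hypothetical helper: wait at most `limit` for the first handshake bytes.
fn read_handshake(stream: &mut TcpStream, limit: Duration) -> io::Result<Vec<u8>> {
    stream.set_read_timeout(Some(limit))?;
    let mut buf = vec![0u8; 1024];
    let n = stream.read(&mut buf)?; // fails with WouldBlock/TimedOut on expiry
    buf.truncate(n);
    Ok(buf)
}

// Simulates a client that connects and then stays silent.
fn silent_client_times_out() -> io::Result<bool> {
    let listener = TcpListener::bind("127.0.0.1:0")?; // port 0: let the OS pick
    let _client = TcpStream::connect(listener.local_addr()?)?; // never writes
    let (mut socket, _) = listener.accept()?;

    match read_handshake(&mut socket, Duration::from_millis(100)) {
        // The error kind for an expired timeout differs between platforms.
        Err(e) => Ok(matches!(
            e.kind(),
            io::ErrorKind::WouldBlock | io::ErrorKind::TimedOut
        )),
        Ok(_) => Ok(false),
    }
}

fn main() -> io::Result<()> {
    println!("silent client timed out: {}", silent_client_times_out()?);
    Ok(())
}
```

With the timeout in place, a connection that never sends its handshake is dropped after the limit instead of holding a task open forever.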

That’s why I started improving the code by enforcing timeouts, requiring masking, and eventually adding TLS support to secure the communication.

Autobahn Tests

Finally, I used the Autobahn Test Suite, an essential tool for testing the conformance and interoperability of WebSocket implementations. The Autobahn suite ensures that both WebSocket clients and servers adhere to the WebSocket protocol as specified in RFC 6455. It helps verify that implementations correctly handle frames, manage connections, and follow the rules regarding payloads, opcodes, fragmentation, and more.

Running the Autobahn Test Suite was a game changer. It helped me uncover bugs and issues I hadn’t anticipated. I’m incredibly grateful for this tool because it provided clear insights into what needed improvement and what I had to fix.

In many ways, Autobahn became my personal TDD (Test-Driven Development) assistant. I ran it through my code, identified numerous broken test cases, and worked methodically to address each one.

To give you a clearer idea, Autobahn is a command-line tool that can test both server and client WebSocket implementations. For example, if you want to test your WebSocket server locally, you start your server and then run the Autobahn test suite. It performs a series of test cases and generates an HTML report showing the results.

Spectacular! The report is a web page with all the details of the test cases: their inputs and outputs, and the reason why any of them failed.

socket-flow has a CI/CD pipeline that runs this tool every time the code changes. It also uploads the results to an S3 bucket, so the latest test results are always accessible online.

Some Tweaks for DX

Remember when we initially provided users with a WSConnection struct for each connection? This struct contained some Tokio channels for receiving and sending messages. While this worked, from a Developer Experience (DX) standpoint, it wasn’t ideal to require users to have prior knowledge of Tokio channels. Even though some developers might consider this knowledge essential, it still forced users to write extra logic to read from these channels, adding complexity.

Here’s an old piece of code where we were exposing these channels to the end-user:

use log::*;
use simple_websocket::handshake::perform_handshake;
use std::net::SocketAddr;
use tokio::net::{TcpListener, TcpStream};
use tokio::select;

async fn handle_connection(_: SocketAddr, stream: TcpStream) {
    match perform_handshake(stream).await {
        Ok(mut ws_connection) => loop {
            select! {
                Some(result) = ws_connection.read.recv() => {
                    match result {
                        Ok(message) => {
                            if ws_connection.send_data(message).await.is_err() {
                                eprintln!("Failed to send message");
                                break;
                            }
                        }
                        Err(err) => {
                            eprintln!("Received error from the stream: {}", err);
                            break;
                        }
                    }
                }
                else => break
            }
        },
        Err(err) => eprintln!("Error when performing handshake: {}", err),
    }
}

#[tokio::main]
async fn main() {
    env_logger::init();

    let addr = "127.0.0.1:9002";
    let listener = TcpListener::bind(&addr).await.expect("Can't listen");
    info!("Listening on: {}", addr);

    while let Ok((stream, _)) = listener.accept().await {
        let peer = stream
            .peer_addr()
            .expect("connected streams should have a peer address");
        info!("Peer address: {}", peer);

        tokio::spawn(handle_connection(peer, stream));
    }
}

You can see that we needed select! logic to handle the channels. While it wasn’t too cumbersome, there was room for improvement. I want to emphasize that this code snippet was part of an example, showing how a user could interact with the library.

Improving the Interface

To improve the user experience, I decided to implement the futures::Stream trait for the WSConnection struct. This way, users wouldn’t need to deal with channels directly, making the API more intuitive.

In Rust, futures::Stream is a trait for handling asynchronous sequences of values that may not have arrived yet. It’s similar to the Iterator trait but designed for asynchronous operations. Instead of blocking while waiting for the next item, a stream is polled; the next() method from StreamExt wraps that polling in a Future you can await.

A Stream has an associated Item type, and calling the poll_next method on a stream will either return Poll::Pending if no items are ready, or Poll::Ready(Some(item)) when an item is available. If it returns Poll::Ready(None), it indicates that the stream has ended.

Here’s the trait implementation:

// WSConnection has the read_rx attribute, which is already a ReceiverStream.
// However, we don't want this attribute visible to the end user.
// Implementing Stream for the struct lets the end user call next() and other
// stream methods directly on a variable that holds it.
impl Stream for WSConnection {
    type Item = Result<Message, Error>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // We need to get a mutable reference to the inner field
        let this = self.get_mut();

        // Delegate the polling to `read_rx`
        // We need to pin `read_rx` because its `poll_next` method requires the object to be pinned
        Pin::new(&mut this.read_rx).poll_next(cx)
    }
}
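To see the three possible poll_next outcomes in isolation, here is a std-only sketch. MiniStream is a local trait declared with the same shape as futures::Stream purely so the example compiles without external crates; Counter, NoopWaker, and collect_all are likewise invented for this illustration.

```rust
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

// Local trait with the same shape as `futures::Stream` (an assumption made so
// the sketch needs no external crates).
trait MiniStream {
    type Item;
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>;
}

// Hypothetical stream that yields 1..=limit, reporting Pending before each item.
struct Counter {
    next: u32,
    limit: u32,
    ready: bool,
}

impl MiniStream for Counter {
    type Item = u32;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<u32>> {
        let this = self.get_mut(); // allowed because Counter is Unpin
        if this.next > this.limit {
            return Poll::Ready(None); // the stream has ended
        }
        if !this.ready {
            this.ready = true;
            cx.waker().wake_by_ref(); // ask to be polled again
            return Poll::Pending; // no item available yet
        }
        this.ready = false;
        this.next += 1;
        Poll::Ready(Some(this.next - 1))
    }
}

struct NoopWaker;

impl Wake for NoopWaker {
    fn wake(self: Arc<Self>) {}
}

// Drive the stream by polling in a loop; a real runtime would sleep until woken.
// Returns the collected items and how many times the stream reported Pending.
fn collect_all(mut stream: Counter) -> (Vec<u32>, usize) {
    let waker = Waker::from(Arc::new(NoopWaker));
    let mut cx = Context::from_waker(&waker);
    let mut items = Vec::new();
    let mut pending = 0;
    loop {
        match Pin::new(&mut stream).poll_next(&mut cx) {
            Poll::Ready(Some(item)) => items.push(item),
            Poll::Ready(None) => return (items, pending),
            Poll::Pending => pending += 1,
        }
    }
}

fn main() {
    let (items, pending) = collect_all(Counter { next: 1, limit: 3, ready: false });
    println!("items: {:?}, pending polls: {}", items, pending);
}
```

With the real futures crate, next().await hides this loop entirely, which is what makes the end-user code so short.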

Finally, the end-user code:

use futures::StreamExt;
use log::*;
use socket_flow::handshake::accept_async;
use socket_flow::stream::SocketFlowStream;
use std::net::SocketAddr;
use tokio::net::{TcpListener, TcpStream};

async fn handle_connection(_: SocketAddr, stream: TcpStream) {
    match accept_async(SocketFlowStream::Plain(stream)).await {
        Ok(mut ws_connection) => {
            while let Some(result) = ws_connection.next().await {
                match result {
                    Ok(message) => {
                        if ws_connection.send_message(message).await.is_err() {
                            error!("Failed to send message");
                            break;
                        }
                    }
                    Err(e) => {
                        error!("Received error from the stream: {}", e);
                        break;
                    }
                }
            }
        }
        Err(err) => error!("Error when performing handshake: {}", err),
    }
}

#[tokio::main]
async fn main() {
    env_logger::init();

    let addr = "127.0.0.1:9002";
    let listener = TcpListener::bind(&addr).await.expect("Can't listen");
    info!("Listening on: {}", addr);

    while let Ok((stream, peer)) = listener.accept().await {
        info!("Peer address: {}", peer);
        tokio::spawn(handle_connection(peer, stream));
    }
}

We can clearly see how the code became much more concise, and now users don't have to deal directly with the channel logic. This made the interface more user-friendly and efficient! Another improvement was in error handling, where users could capture specific errors for each part of the library. Since we’re already talking about errors, let’s dive into that.

Error Handling

I made a strong effort to cover all possible errors for each function call, and one key tool that helped a lot was the thiserror Rust library. It provides a convenient derive macro for the standard library’s std::error::Error trait. This made life easier when creating custom errors, without the need to implement the traits manually for each one!
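To give an idea of the boilerplate that thiserror removes, here is roughly what implementing a custom error by hand looks like. The variant names come from the read_frame function shown earlier; the FrameError enum name, the messages, and the check_control_length helper are assumptions made for this sketch.

```rust
use std::error::Error;
use std::fmt;

// Variant names are taken from `read_frame`; the messages are illustrative.
#[derive(Debug, PartialEq)]
enum FrameError {
    RSVNotZero,
    ControlFramesFragmented,
    ControlFramePayloadSize,
    PayloadSize,
}

// This Display impl is the part a derive macro would generate from attributes.
impl fmt::Display for FrameError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            FrameError::RSVNotZero => "reserved bits must be zero",
            FrameError::ControlFramesFragmented => "control frames must not be fragmented",
            FrameError::ControlFramePayloadSize => "control frame payload exceeds 125 bytes",
            FrameError::PayloadSize => "payload exceeds the configured maximum",
        };
        f.write_str(message)
    }
}

// An empty impl is enough; `Error` only requires Debug + Display.
impl Error for FrameError {}

// Hypothetical helper mirroring one of the checks in `read_frame`.
fn check_control_length(length: usize) -> Result<(), FrameError> {
    if length > 125 {
        Err(FrameError::ControlFramePayloadSize)
    } else {
        Ok(())
    }
}

fn main() {
    if let Err(e) = check_control_length(126) {
        // The custom error works anywhere a `dyn Error` is expected.
        let boxed: Box<dyn Error> = Box::new(e);
        println!("error: {boxed}");
    }
}
```

With thiserror, the Display and Error impls collapse into a derive plus one attribute per variant, which adds up quickly once a library has dozens of error cases.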

Final Considerations

There’s so much more to discuss about this library, but I’ll wrap it up here since this article is getting quite lengthy 😅. Reflecting on all the challenges I faced, the biggest insight I gained from this experience is that even though I’ve been working with Rust for the last 4 years, these past 3 months building this library have taught me more than the entire 4 years combined!

In our daily jobs, we may encounter tough features to implement or seemingly impossible bugs. However, over time, we tend to settle into a routine with how we code at our current company. That’s why so many people say that they learned the most in the early years of a new opportunity.

Personal projects like this one push us outside our comfort zones, making us tackle issues that only a small group of developers might have faced. It forces us to either become part of that group or... copy and paste solutions from them! 😅🤣

In any case, I believe projects like this are a way to level up your skills and open up new career opportunities, not to mention the potential for increasing your income! 🤑

Also, in case you forgot, consider giving it a star!

socket-flow