marius.vision


Introducing Duck Unions: Python, Pydantic, and Discriminated Unions

Created: 2025-11-06 23:55:16 | Last Modified: 2025-11-07 00:07:29


I.

I really like sum types and pattern matching on them. For example:

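from dataclasses import dataclass
from typing import Callable, Tuple

AuthToken = str  # stand-in: substitute whatever your real auth token type is

@dataclass
class ConnectedClient:
    ip: Tuple[int, int, int, int]

@dataclass
class AuthenticatedClient:
    send: Callable[[str], None]
    auth: AuthToken

@dataclass
class Disconnected:
    pass

ClientConnection = ConnectedClient | AuthenticatedClient | Disconnected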

This is very clean and straightforward. We can make it very clear that you're not supposed to send anything to a client that is merely connected, and, vice versa, that you no longer need to concern yourself with something like an IP address once a client is authenticated.

Also, I can take one look at the actual sum type definition of ClientConnection, and I know roughly what to expect from client connections. Yeah, they can only be connected, authenticated, or disconnected. It reads like poetry.

Since python 3.10 introduced structural pattern matching (the match statement), it's also no longer super painful to deal with sum types like this.

from typing import List, Optional, assert_never  # assert_never needs Python 3.11+

# authenticate() is assumed to be defined elsewhere in the server
def handle_client_connection(client: ClientConnection, broadcast_messages: List[str]) -> Optional[ClientConnection]:
    match client:
        case ConnectedClient() as connected_client:
            return authenticate(connected_client)
        case AuthenticatedClient() as authenticated_client:
            for msg in broadcast_messages:
                authenticated_client.send(msg)
            return authenticated_client
        case Disconnected() as disconnected_client:
            return None
        case _ as unreachable:
            assert_never(unreachable)

Don't ask me what this code is supposed to do. Just imagine it's part of an IRC server. And don't ask me what IRC is either, it's what your grandpa used to pirate software with.

Now, for anxious paranoiacs like myself, code like the above is very soothing and relaxing. It's clear, it works, and, this is essential, it is guaranteed to keep working in the future. Specifically, the assert_never voodoo at the bottom guarantees that mypy will report an error if we ever add another concrete type to the ClientConnection union and forget to update the match patterns.

To really understand the appeal of sum types, you have to consider what equivalent code would look like with other paradigms. A naive approach might stuff everything into one big ClientConnection class that has bool flags to show whether it's authenticated, which would then conditionally enable or disable certain methods, like send. That's not great.
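For contrast, a rough sketch of that flag-based design (FlagClient and its fields are hypothetical). Every field exists in every state, and the only thing standing between an unauthenticated client and send is a runtime check that the type checker knows nothing about:

```python
from dataclasses import dataclass, field
from typing import List, Optional, Tuple


@dataclass
class FlagClient:
    # Every field exists in every state, whether it makes sense there or not.
    ip: Optional[Tuple[int, int, int, int]] = None
    auth: Optional[str] = None
    authenticated: bool = False
    disconnected: bool = False
    outbox: List[str] = field(default_factory=list)  # stands in for a real socket

    def send(self, msg: str) -> None:
        # "Disabled in this state" has to be enforced by hand, at runtime.
        if self.disconnected or not self.authenticated:
            raise RuntimeError("send() called on a client that can't send")
        self.outbox.append(msg)
```

Nothing stops a caller from setting authenticated=True while auth is still None; with the sum type, that combination simply can't be constructed.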

Another alternative is to use an abstract base class, e.g. a ClientConnectionInterface. This is slightly saner, but you still run into trouble. Is send an abstract method? If you say no, you now have to litter your code with if isinstance(con, Disconnected) checks etc. If you say yes, the send method basically has to do nothing for disconnected and unauthenticated clients. Two thirds of your interface implementors are now just cardboard decoys. Close but no cigar.
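A sketch of the "yes, send is abstract" version, reusing the names from above as plain, hypothetical classes (the outbox list and the broadcast helper are made up for illustration). Two of the three implementations exist only to satisfy the interface:

```python
from abc import ABC, abstractmethod
from typing import List


class ClientConnectionInterface(ABC):
    @abstractmethod
    def send(self, msg: str) -> None: ...


class AuthenticatedClient(ClientConnectionInterface):
    def __init__(self) -> None:
        self.outbox: List[str] = []  # stands in for a real socket

    def send(self, msg: str) -> None:
        self.outbox.append(msg)


class ConnectedClient(ClientConnectionInterface):
    def send(self, msg: str) -> None:
        pass  # cardboard decoy: not authenticated, so silently drop the message


class Disconnected(ClientConnectionInterface):
    def send(self, msg: str) -> None:
        pass  # cardboard decoy: nobody is listening


def broadcast(clients: List[ClientConnectionInterface], msg: str) -> None:
    # Type-checks fine, but two out of three send() calls do nothing.
    for client in clients:
        client.send(msg)
```

The type checker is happy, and that's the problem: it can no longer tell you that sending to a disconnected client is meaningless.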

I mean, it's not that hard. If you're going to model an actual disjunction in your problem domain, why not use, you guessed it, a disjunction to do it?

II.

So JSON is a thing. And serialization. And deserialization. I want to make my types into strings, and once they are strings, I want to stuff them into wall sockets and have them come out fine on the other side. The thing that makes this possible in python is called pydantic. You may have heard of it. It's kind of a big deal.

Only there's a problem. Consider JSON like this:

{
    "ip": [192, 168, 0, 1]
}

If we make our above types subclasses of pydantic's BaseModel, we can easily call ConnectedClient.model_validate_json on the JSON above and get a ConnectedClient instance out of it. With some additional effort and custom validators, we could probably even validate that the auth token in an AuthenticatedClient is legit. However, all of that only works if we already know which type the JSON represents. In reality, we might just get random JSON strings coming in and need to turn them into some kind of connection, without knowing in advance which kind.
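A minimal sketch of the happy path, assuming pydantic v2 (model_validate_json is the v2 name). ConnectedClient is trimmed to just its ip field, and the fixed-length tuple annotation gives us some validation for free:

```python
from typing import Tuple

from pydantic import BaseModel, ValidationError


class ConnectedClient(BaseModel):
    ip: Tuple[int, int, int, int]


# Works, because we already know this JSON is meant to be a ConnectedClient.
client = ConnectedClient.model_validate_json('{"ip": [192, 168, 0, 1]}')
print(repr(client))  # ConnectedClient(ip=(192, 168, 0, 1))

# A three-part IP doesn't fit Tuple[int, int, int, int] and is rejected.
try:
    ConnectedClient.model_validate_json('{"ip": [192, 168, 0]}')
except ValidationError as err:
    print(err)
```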

We really want to call ClientConnection.model_validate_json(json_string). Like, really, really want to.

Only we can't. ClientConnection isn't a real boy. It's a sham. Just a type alias. You can't subclass it, you can't instantiate it, and it has no methods. It is the Waluigi of python types.

Also, even if we could call ClientConnection.model_validate_json, which we can't, how would it know which type to construct? I mean, you can kind of tell with the types above, and indeed, pydantic is smart enough to discriminate in obvious cases, but what if our union contains two types that are structurally identical?

The universal answer to this problem is called discriminated unions. Every type in the union gets the same extra field, and each type stores a different constant in it, so the value of that one field discriminates it against all the other types in its union. We might call this field 'type'. It might look like this:

from typing import Literal, Tuple
from pydantic import BaseModel

class ConnectedClient(BaseModel):
    type: Literal["ConnectedClient"] = "ConnectedClient"
    ip: Tuple[int, int, int, int]

We will get to how to tell pydantic that 'type' is a discriminator. First, I have to throw up into my mouth a little.

Yeah, I think it's ugly, and I have a problem with it. What happened to the zen of python? We are writing the class name like 3 times. What is this, java? Don't repeat yourself! Don't repeat yourself!

It's also kind of dangerous. We no longer have a single source of truth for the type name. Programmers might easily make spelling mistakes that won't get caught by static checks, and someone unfamiliar with this nonsense might look at it and just get confused. We need to do better.
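To make the danger concrete, here's a hypothetical typo in a tag, assuming pydantic v2. mypy has no objection to Literal["ConectedClient"], but correctly tagged JSON is then rejected at runtime. The adapter line uses pydantic's Field(discriminator=...) and TypeAdapter, the same machinery the Shape example relies on; Adapter is an illustrative name:

```python
from typing import Annotated, Literal, Tuple, Union

from pydantic import BaseModel, Field, TypeAdapter, ValidationError


class ConnectedClient(BaseModel):
    # Typo in the tag: static type checkers are perfectly happy with this.
    type: Literal["ConectedClient"] = "ConectedClient"
    ip: Tuple[int, int, int, int]


class Disconnected(BaseModel):
    type: Literal["Disconnected"] = "Disconnected"


Adapter = TypeAdapter(Annotated[Union[ConnectedClient, Disconnected], Field(discriminator="type")])

correct_json = '{"type": "ConnectedClient", "ip": [192, 168, 0, 1]}'
try:
    Adapter.validate_json(correct_json)
except ValidationError as err:
    print(err.errors()[0]["type"])  # union_tag_invalid
```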

Thing is, even if we get this right, Waluigi is still not a real boy. We will never be able to call ClientConnection.model_validate_json. Or can we?

III.

That was a rhetorical question, in case you were wondering, and the answer is yes. Yes, we can, and being able to make Waluigi into a real boy is precisely what pydantic's TypeAdapter was made for.
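Even before any discriminator is involved, TypeAdapter already gives a plain union the validate_json and dump_json methods it was missing. A minimal sketch, assuming pydantic v2; Disconnected gets a hypothetical reason field so the two members don't overlap structurally and pydantic's default union validation can tell them apart:

```python
from typing import Tuple, Union

from pydantic import BaseModel, TypeAdapter


class ConnectedClient(BaseModel):
    ip: Tuple[int, int, int, int]


class Disconnected(BaseModel):
    reason: str  # hypothetical field, only here to make the members distinguishable


ClientConnection = Union[ConnectedClient, Disconnected]
ClientConnectionAdapter = TypeAdapter(ClientConnection)

# No discriminator yet: pydantic tries the members and keeps the one that fits.
client = ClientConnectionAdapter.validate_json('{"ip": [192, 168, 0, 1]}')
gone = ClientConnectionAdapter.validate_json('{"reason": "timeout"}')

print(ClientConnectionAdapter.dump_json(client))  # b'{"ip":[192,168,0,1]}'
```

This only works because the members happen to differ. Once two of them look alike, guessing is no longer good enough, which is where the discriminator comes in.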

So, spoiler: we won't be able to get it perfect, but that's kind of the real zen of python, don't you think? Still, I will show you how to build discriminated unions out of abstract base classes and pydantic base models that serialize and deserialize nicely and are minimally painful and confusing.

I'll spare you the didactics and the derivation. Here's the final, ultimate, batteries-included, working code example. Explanations follow below.

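from typing import Annotated, Any, Literal, Union
from pydantic import BaseModel, Field, TypeAdapter, model_validator
from abc import abstractmethod  # no ABC mixin needed: pydantic's model metaclass already derives from ABCMeta
import math
import json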

class ShapeBase(BaseModel):

    @model_validator(mode="before")
    @classmethod
    def set_type_from_class_name(cls, data: Any) -> Any:
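        if isinstance(data, dict) and "type" not in data:
            # cls is the concrete subclass being validated; copy instead of mutating the caller's dict
            data = {**data, "type": cls.__name__}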
        return data

    @abstractmethod
    def area(self) -> float:
        pass


class Circle(ShapeBase):
    type: Literal["Circle"]
    radius: float

    def area(self) -> float:
        return math.pi * (self.radius**2)


class Quadrilateral(ShapeBase):
    type: Literal["Quadrilateral"]
    width: float
    height: float

    def area(self) -> float:
        return self.width * self.height

Shape = Annotated[
    Union[Circle, Quadrilateral],
    Field(discriminator="type")
]

ShapeAdapter = TypeAdapter(Shape)

# For testing
unit_circle = Circle(radius=1)  # note no explicit mention of 'type'
quad_json = json.dumps({"type": "Quadrilateral", "width": 3.5, "height": 2.0})
invalid_quad_json = json.dumps({"width": 3.5, "height": 2.0})

After searching far and wide, this is the best version of a sum type in python (using pydantic) that I have been able to come up with. Let's try it out, with the code above saved as shapes.py, before we talk about it.

>>> from shapes import *
>>> unit_circle.model_dump_json()
'{"type":"Circle","radius":1.0}'
>>> ShapeAdapter.validate_json(quad_json)
Quadrilateral(type='Quadrilateral', width=3.5, height=2.0)

Great. The type fields, although they have no explicit default value, still get set to the correct string. What about the last example? Remember that this one is supposed to fail.

>>> ShapeAdapter.validate_json(invalid_quad_json)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/marius/prog/ai/ghostbox/genv/lib/python3.12/site-packages/pydantic/type_adapter.py", line 446, in validate_json
    return self.validator.validate_json(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
pydantic_core._pydantic_core.ValidationError: 1 validation error for tagged-union[Circle,Quadrilateral]
  Unable to extract tag using discriminator 'type' [type=union_tag_not_found, input_value={'width': 3.5, 'height': 2.0}, input_type=dict]
    For further information visit https://errors.pydantic.dev/2.10/v/union_tag_not_found
>>>

Excellent. Let's look at some of the above in more detail.
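One detail worth spelling out is where the before-validator runs. It belongs to each concrete model, so validating against a concrete class fills in "type" for you, while the TypeAdapter has to read the tag before it knows which model (and therefore which validator) to use. A trimmed-down, self-contained version of the Shape code (area methods omitted), assuming pydantic v2:

```python
import json
from typing import Annotated, Any, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter, ValidationError, model_validator


class ShapeBase(BaseModel):
    @model_validator(mode="before")
    @classmethod
    def set_type_from_class_name(cls, data: Any) -> Any:
        # cls is the concrete subclass being validated, e.g. Quadrilateral.
        if isinstance(data, dict) and "type" not in data:
            data = {**data, "type": cls.__name__}
        return data


class Circle(ShapeBase):
    type: Literal["Circle"]
    radius: float


class Quadrilateral(ShapeBase):
    type: Literal["Quadrilateral"]
    width: float
    height: float


ShapeAdapter = TypeAdapter(Annotated[Union[Circle, Quadrilateral], Field(discriminator="type")])
untagged = json.dumps({"width": 3.5, "height": 2.0})

# Validating against the concrete class: its before-validator supplies the tag.
quad = Quadrilateral.model_validate_json(untagged)
print(repr(quad))  # Quadrilateral(type='Quadrilateral', width=3.5, height=2.0)

# Validating against the union: the tag is needed to pick a model in the first place.
try:
    ShapeAdapter.validate_json(untagged)
except ValidationError as err:
    print(err.errors()[0]["type"])  # union_tag_not_found
```

So the convenience only applies when you already know the concrete class; untagged JSON arriving through the adapter is rejected, which is exactly the failure shown in the traceback above.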

IV.

So, to summarize and put a bow on it: if you want discriminated unions with pydantic and abstract base classes in python done right, this is the recipe. I lovingly call them 'Duck Unions', and they have 4 parts.

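Assuming you want a sum type called 'Duck', you'll need:

1. A base class, say DuckBase, that subclasses pydantic's BaseModel, declares the shared interface with @abstractmethod, and has a mode="before" model_validator that fills in "type" from cls.__name__ when it's missing.
2. One concrete subclass per variant, each with a field type: Literal["ClassName"], where the literal is exactly the class name.
3. The sum type itself: Duck = Annotated[Union[...], Field(discriminator="type")].
4. A TypeAdapter, DuckAdapter = TypeAdapter(Duck), for validating and serializing.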
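Here is that recipe as a minimal, self-contained skeleton for a hypothetical Duck, mirroring the Shape code from section III. The variant names (Mallard, RubberDuck), the volume field, and the quack method are made up for illustration; it assumes pydantic v2:

```python
from abc import abstractmethod
from typing import Annotated, Any, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter, model_validator


# 1. Shared base: pydantic's model metaclass derives from ABCMeta,
#    so @abstractmethod is enforced without mixing in ABC.
class DuckBase(BaseModel):
    @model_validator(mode="before")
    @classmethod
    def set_type_from_class_name(cls, data: Any) -> Any:
        if isinstance(data, dict) and "type" not in data:
            data = {**data, "type": cls.__name__}
        return data

    @abstractmethod
    def quack(self) -> str: ...


# 2. One concrete class per variant; the tag literal equals the class name.
class Mallard(DuckBase):
    type: Literal["Mallard"]
    volume: int = 3

    def quack(self) -> str:
        return "Quack" + "!" * self.volume


class RubberDuck(DuckBase):
    type: Literal["RubberDuck"]

    def quack(self) -> str:
        return "Squeak!"


# 3. The sum type itself, discriminated on "type".
Duck = Annotated[Union[Mallard, RubberDuck], Field(discriminator="type")]

# 4. The adapter that lets Duck validate and serialize like a model.
DuckAdapter = TypeAdapter(Duck)
```

DuckAdapter.validate_json('{"type": "RubberDuck"}') gives you a RubberDuck whose quack() returns "Squeak!", and DuckAdapter.dump_json(Mallard(volume=2)) produces b'{"type":"Mallard","volume":2}'. As with Circle(radius=1), a static type checker may still complain that Mallard(volume=2) omits type, which is part of why this isn't quite perfect.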

There you have it. Finally Waluigi became a real boy.


Tags: programming