How should we think about error-handling in concurrent programs? In single-threaded programs, we've mostly converged on a standard pattern, albeit with a diverse zoo of concrete implementations. When an error occurs, it is propagated up the stack until we find a stack frame which is prepared to handle it. As we do so, we unwind the stack frames in order, giving each frame the opportunity to clean up or destroy resources as appropriate.
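As a minimal sketch of that single-threaded pattern (a toy example, not from any particular codebase): the error propagates upward, and each frame it unwinds through gets a chance to clean up before the frame that knows what to do finally handles it.

```python
import json

def read_config(path):
    f = open(path)
    try:
        return json.loads(f.read())   # if this raises, the finally still runs...
    finally:
        f.close()                     # ...as this frame unwinds

def main():
    try:
        cfg = read_config("app.conf")
    except (OSError, json.JSONDecodeError):
        cfg = {}                      # the frame prepared to handle the error
    return cfg
```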
I've recently been doing a lot of both submitting and reviewing pull requests to PyTorch that were authored with substantial LLM assistance. This is a big change from earlier this year, when it was clear LLMs worked well for greenfield projects but the code was too hopelessly sloppy for a production codebase. Here are my merged PRs that mention claude code in their description; Jason Ansel has also had a similar experience (Meta only link, here is the list of issues he referenced in his writeup). There has already been increasing discourse (Simon Willison, LLVM) on how code review should adapt to this new era of LLMs. My contribution to this discourse is this: within teams, code review should change to being primarily a human alignment mechanism.
Here is a simple example: it is well known that LLMs are prone to generating overly defensive code: e.g., they will be constantly sprinkling try...catch everywhere or testing if a variable is some type when system invariants imply that it should always be that type. If someone sends me a PR with these problems, I am not commenting on these problems solely because I want them to be fixed. If that's all I cared about, I could have just fed my comments directly to claude code. The real problem is that the human who was operating the LLM didn't agree with me that this defensive code was bad, and the point of the review is to align them with me on what is overly defensive versus not. In the most trivial cases, maybe the engineer didn't read the LLM output, in which case the remedy is to make them actually read the code. But sometimes real human work has to happen; for example, maybe there is a global system invariant that one has to understand to know if the defensiveness is necessary or not. If we agree about the global system invariants, there's no reason the code review has to go through me: the original code author can just instruct the LLM to fix problems and keep me out of the loop until they have aligned the LLM output to themselves--at which point we should do the more expensive human to human alignment. The ideal is that I don't need to ever write review comments about mechanical problems, because they have already been fixed by the original author ahead of time.
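A toy illustration of the kind of defensiveness at issue (my own example, not from any actual PR): the first version re-checks an invariant the caller already guarantees and swallows errors that should surface loudly; the second trusts the invariant.

```python
import torch

# Overly defensive: re-validates the type and hides failures.
def scale_defensive(x, factor):
    if not isinstance(x, torch.Tensor):   # invariant: x is always a Tensor here
        x = torch.tensor(x)
    try:
        return x * factor
    except Exception:
        return x

# Trusting the invariant: shorter, and a violation fails loudly where it happened.
def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor
```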
Conversely, when I am putting up an LLM generated PR for human review, I am trying to transmit higher level information. How does the new code work? What do I need to know about the existing system to understand this code? This doesn't even have to be in the PR description: if the LLM proposes a fix that I myself don't understand, or seems difficult to understand, I will simply instruct it to try it a different way, until the resulting diff is obviously correct. Tokens are cheap: we should expect more out of the author of code, because the cost of generating these PRs has gone way down. Similarly, I am willing to throw out the code and start again; you don't have to feel bad about wasting my time (I didn't type it! I spent my time understanding the problem, and none of that is regretted.)
There is a lot of scaremongering about how engineers who don't pick up AI tools will be left behind. My take on this is that there are a number of different skills that make up what it means to be a good software engineer, and LLM coding, even today, is clearly reweighting the relative importance of these skills. I care a lot more about your ability to read code, reason about the big picture, communicate clearly and have good taste, than I care about your ability to mechanically write code. There is an archetype of junior engineer who is not that good at coding but very good at the softer, higher level skills, and I think they will be very valuable in this new world order. Conversely, I think going forward I will have substantially less patience if I have to keep telling you the same things over and over, because I just don't value raw "ability to code" as much anymore. My ideal state is like that with long time senior teammates: I can trust that they have made good low level decisions, and I can focus on understanding the bigger picture and updating my mental model of how the system works.
Today's LLMs have no memory: they have to rediscover everything in the system from first principles every time they are run. The purpose of the humans, of the team, is to collectively maintain a shared vision of what, platonically, the system should do. I want code review to reconfigure itself around this purpose.
by Edward Z. Yang at December 20, 2025 10:48 PM
Famously, PyTorch and JAX don't agree on how shardings should be represented: PyTorch takes a mesh-dim oriented view, where for each dimension in your device mesh, you specify what sharding should be applied; JAX takes a tensor-dim oriented view, where for each dimension on your tensor, you say which mesh dimensions (potentially multiple!) shard it. Among my Twitter followers, it is generally agreed that the JAX formulation is more intuitive from a user perspective. OK, fine; if you prefer one representation over another, it's easy enough to translate between the two representations (in easy situations, at least!) In this post, I want to talk more about the framework implementation side: what is the better internal representation of sharding? I don't claim to have all the answers, but my motivation for writing this post is to help explain where I currently stand and how I evaluate proposals for evolving DTensor and sharding in PyTorch.
Closed versus open. I am going to make a precise technical claim: JAX sharding is closed, whereas PyTorch sharding is (in principle) open. Here, what I mean by closed/open refers to the capability for users to extend a system: traditional ADTs are closed (you can't add another constructor to an ADT), whereas object-oriented classes are open (you can define a new subclass of a class). Now, technically JAX sharding is open: jax.sharding.Sharding is a base class that is intended to be subclassed, but to do this you have to define things like _to_xla_hlo_sharding, which is as good as not being supported. The regular class everyone uses, NamedSharding, consists of a mesh and a tuple of mesh axes, with no obvious extension points. I also offer for the defense this unanswered forum post: https://github.com/jax-ml/jax/discussions/23703
In contrast, PyTorch sharding is in principle extensible: the sharding is expressed as a list of Placement, a class which is subclassed to define custom shardings. The extensibility of Placement isn't really well supported (for example, there's no convenient way to add sharding rules for new placements), but it works well enough that both internally and externally there are implementations of weird placements (internally, StridedShard and NormPartial... and technically all of the non-sum reductions supported by Partial as well as uneven sharding; externally, see RaggedShard and InterleavedShard).
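To make the contrast concrete, here is roughly what sharding a tensor's dim 0 over the "dp" mesh axis looks like in each representation (a sketch; the mesh sizes and names are made up, and both snippets assume a real multi-device setup with the runtime already initialized):

```python
# JAX: tensor-dim oriented -- for each tensor dim, name the mesh axes sharding it.
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

jmesh = Mesh(np.asarray(jax.devices()).reshape(4, 2), ("dp", "tp"))
spec = NamedSharding(jmesh, P("dp", None))  # dim 0 sharded over "dp", dim 1 unsharded

# PyTorch: mesh-dim oriented -- for each mesh dim, give a Placement.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Shard, Replicate

tmesh = init_device_mesh("cuda", (4, 2), mesh_dim_names=("dp", "tp"))
dt = distribute_tensor(torch.randn(8, 8), tmesh, placements=[Shard(0), Replicate()])
```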
Why does mesh-dim oriented sharding support extensibility in this way? The key is that mesh-oriented sharding is very imperative in nature: you can think of the list of placements as a sequence of transformations you apply to the tensor from left-to-right. Concretely, given the current local tensor (as produced by all of the placements you handled for the mesh dims before the one you're currently processing), run an invertible function to split this tensor along the current mesh dimension. This gives you a bunch of new local tensors which you recursively continue sharding with the rest of the mesh dims. The invertibility of the function is the only real constraint on what function you can provide (since you need to be able to reassemble the shards back into the original full tensor), but otherwise your choice of function is unconstrained. It is in this sense that Placement is morally extensible.
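Schematically, that left-to-right interpretation looks something like this (my own pseudocode, not DTensor's actual implementation; `split_for` is a hypothetical method standing in for "some invertible split"):

```python
# Conceptual sketch: each placement splits the current local tensor along its
# mesh dimension; the only hard requirement is that the shards can be
# reassembled into the original tensor (invertibility).
def compute_local_tensor(full_tensor, placements, mesh_shape, my_coord):
    local = full_tensor
    for placement, dim_size, dim_coord in zip(placements, mesh_shape, my_coord):
        shards = placement.split_for(local, num_shards=dim_size)  # invertible split
        local = shards[dim_coord]   # keep this rank's piece, then handle the next mesh dim
    return local
```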
When designing systems, it is not an unambiguous good to make the system more flexible. Closed systems like JAX's mean you don't have to worry about hilariously complicated situations like what if you unevenly shard on the same dimension multiple times (do you have any guarantees on the local sizes of tensors being somewhat balanced?) But sometimes, the use case demands a greater degree of expressivity (in the same way that manual memory management allows you to do more than you can conveniently do in a GC'ed language.)
How expressive does Sharding have to be? One of the primary value propositions of DTensor is that it specifies a standardized representation for saying how a tensor is sharded across your cluster. It's very good to have this information, because it prevents accidents, like forgetting that a tensor dimension is sharded, doing a reduction on that dimension without first doing a collective, and getting subtly wrong results that take weeks to debug. It's better to have a system that is correct but slow, than it is to have a system that is fast but incorrect.
Being able to express all distributed states is not a terminal goal. There are lots of situations in distributed optimizations where you temporarily need to put the system in a state where it is very difficult to describe exactly how to interpret data across nodes. For example, when you implement ring attention, to avoid communications when performing softmax, you instead perform online softmax. It's quite difficult to say what the "placements" of the running quantities in online softmax are. In this case, we shouldn't overly stress ourselves with defining a placement: we should just use local_map or shard_map and absolve ourselves of needing to actually say exactly how data is laid out at any given point in time. But the key is that we should only do this in local regions of code; if we give up and local_map our entire model, we might as well have just not written our code with DTensor at all. So we should seek additional expressivity when it is needed to express how data is being communicated across system boundaries.
Here are some classic examples from LLM training where you need a little bit of extra expressivity, starting with simple cases and becoming more complicated:
I think it's a worthy use of complexity budget to search for a system design that can handle all of these things, especially since PyTorch's existing mesh-oriented sharding is already tantalizingly close to supporting this.
Why is adding a new Placement to PyTorch hard? I tend to think, fundamentally, that a mesh-oriented sharding strategy can support arbitrary Placement subclasses. So why does this not work so well in PyTorch? I think there really only are two issues:
With these two issues fixed, and a bit of careful design on what the overrideable API on Placement is for subclasses, I think we can have a very good extensibility story for shardings.
by Edward Z. Yang at December 08, 2025 05:35 AM
I have recently needed to draw the contents of high-dimensional (e.g., 4D and up) tensors where it is important to ensure that it is clear how to identify each of the dimensions in the representation. Common strategies I've seen people use in this situation include printing a giant list of 2D slices (what the default PyTorch printer will do) or flattening the Tensor in some way back down to a 2D tensor. However, if you have a lot of horizontal space, there is a strategy that I like that makes it easy to identify all the axes of the higher dimensional tensor: draw it as a matrix of matrices.
Here are some examples, including the easy up-to-2D cases for completeness.
0D: torch.arange(1).view(())
0
1D: torch.arange(2)
0 1
2D: torch.arange(4).view(2, 2)
0 1
2 3
3D: torch.arange(8).view(2, 2, 2)
0 1 4 5
2 3 6 7
4D: torch.arange(16).view(2, 2, 2, 2)
0 1 4 5
2 3 6 7

8 9 12 13
10 11 14 15
5D: torch.arange(32).view(2, 2, 2, 2, 2):
0 1 4 5 : 16 17 20 21
2 3 6 7 : 18 19 22 23
:
8 9 12 13 : 24 25 28 29
10 11 14 15 : 26 27 30 31
The idea is that every time you add a new dimension, you alternate between stacking the lower dimension matrices horizontally and vertically. You always stack horizontally before stacking vertically, to follow the standard row-major convention for printing in the 2D case. Dimensions always proceed along the x and y axis, but the higher dimensions (smaller dim numbers) involve skipping over blocks. For example, a "row" on dim 3 in the 4D tensor is [0, 1] but the "row" on dim 1 is [0, 4] (we skip over to the next block.) The fractal nature of the construction means we can keep repeating the process for as many dimensions as we like.
In fact, for the special case when every size in the tensor is 2, the generated sequence of indices form a Morton curve. But I don't call it that, since I couldn't find a popular name for the variation of the Morton curve where the radix of each digit in the coordinate representation can vary.
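Since the construction is recursive, it is easy to code up. Here is a quick sketch of a printer in this style (my own toy code, not from the original post; the spacing between blocks differs slightly from the hand-drawn examples above):

```python
import torch

def render(t, sep=2):
    # Recursively render a tensor as a "matrix of matrices", returning text lines.
    # Sub-blocks along dim 0 are stacked horizontally when the tensor has an odd
    # number of dims and vertically when even, so the last two dims always print
    # as an ordinary 2D matrix.
    if t.dim() == 0:
        return [str(t.item())]
    if t.dim() == 1:
        return [" ".join(str(v.item()) for v in t)]
    blocks = [render(sub, sep) for sub in t]
    if t.dim() % 2 == 0:
        # Even dims: stack vertically, with a blank row between multi-row blocks.
        out = []
        for i, b in enumerate(blocks):
            if i > 0 and t.dim() > 2:
                out.append("")
            out.extend(b)
        return out
    # Odd dims: stack horizontally, padding each block to a common width.
    widths = [max(len(line) for line in b) for b in blocks]
    height = max(len(b) for b in blocks)
    return [
        (" " * sep).join(
            (b[r] if r < len(b) else "").ljust(w) for b, w in zip(blocks, widths)
        ).rstrip()
        for r in range(height)
    ]

print("\n".join(render(torch.arange(16).view(2, 2, 2, 2))))
```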
Knowledge check. For the 4D tensor of size (2, 2, 2, 2) arranged in this way, draw the line(s) that would split the tensor into the pieces produced by torch.split(x, 1, dim), for each possible dimension 0, 1, 2 and 3. Answer under the fold.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
dim=0
>>> [x.reshape(-1) for x in torch.arange(16).view(2,2,2,2).split(1,dim=0)]
[tensor([0, 1, 2, 3, 4, 5, 6, 7]), tensor([ 8, 9, 10, 11, 12, 13, 14, 15])]
0 1 4 5
2 3 6 7
----------------
8 9 12 13
10 11 14 15
dim=1
>>> [x.reshape(-1) for x in torch.arange(16).view(2,2,2,2).split(1,dim=1)]
[tensor([ 0, 1, 2, 3, 8, 9, 10, 11]), tensor([ 4, 5, 6, 7, 12, 13, 14, 15])]
0 1 | 4 5
2 3 | 6 7
|
8 9 | 12 13
10 11 | 14 15
dim=2
>>> [x.reshape(-1) for x in torch.arange(16).view(2,2,2,2).split(1,dim=2)]
[tensor([ 0, 1, 4, 5, 8, 9, 12, 13]), tensor([ 2, 3, 6, 7, 10, 11, 14, 15])]
0 1 4 5
------- -------
2 3 6 7
8 9 12 13
------- -------
10 11 14 15
dim=3
>>> [x.reshape(-1) for x in torch.arange(16).view(2,2,2,2).split(1,dim=3)]
[tensor([ 0, 2, 4, 6, 8, 10, 12, 14]), tensor([ 1, 3, 5, 7, 9, 11, 13, 15])]
0 | 1 4 | 5
2 | 3 6 | 7
8 | 9 12 | 13
10 | 11 14 | 15
by Edward Z. Yang at October 25, 2025 04:55 PM
For a while now, I've been fascinated by Z3 and by SMT solving more broadly. While on paternity leave recently, I was reminded of the existence of regular-expression crossword puzzles, and allowed myself to get nerdsniped into writing a Z3-backed solver. I expected to spend perhaps an afternoon cranking out a quick solver; I ended up getting sucked into understanding and debugging Z3 performance, and learning far more about Z3 and about SMT than I expected.
With contributions from Richard Zou.
PT2's dominant internal representation, FX graphs, do not directly support control flow (if statements, while loops): they only represent straight-line basic blocks. Most of our graph capture mechanisms are tracing based (fx.symbolic_trace, make_fx, Dynamo), which means that we expect to be able to linearize all conditionals we encounter into a straight-line program. Sometimes, you want to work with code that has control flow while working with the compiler stack. There is no silver bullet; instead, there are a lot of different options with different tradeoffs.
We have a perfectly good general purpose language that supports control flow: Python. To handle control flow, compile only regions/submodules of your program that have no internal control flow, and then string them together with standard Python control flow constructs, as in the sketch below. PT2 compiled regions are compositional with non-compiled regions, so "it works."
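A minimal sketch of this pattern (function names are made up): the straight-line pieces are compiled, and a plain Python `if` stitches them together in eager mode.

```python
import torch

@torch.compile
def encode(x):
    # straight-line region: no internal control flow
    return torch.relu(x @ x.T)

@torch.compile
def decode(h):
    return torch.sigmoid(h.sum(dim=-1))

def model(x, use_decoder: bool):
    h = encode(x)          # compiled region
    if use_decoder:        # ordinary Python control flow between regions
        return decode(h)   # compiled region
    return h

model(torch.randn(4, 4), True)
```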
Pro:
Cons:
Link: Reducing torch.compile cold start compilation time with regional compilation
When the control flow is controlled by arguments that are known ahead of time (not data-dependent), you can also compile at the top level and get the flattened straight-line program for the particular branching you had in this case. Because Dynamo is a symbolic bytecode interpreter, it can automatically determine what inputs were used as part of control flow, and generate guards to validate that we would take the same paths again. If those values change, we will recompile the program for the new values. We dispatch between all the different unrollings of the program we have generated.
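Roughly, this is what that looks like in practice (a small sketch):

```python
import torch

@torch.compile
def f(x, use_sin: bool):
    # `use_sin` is an ordinary Python bool, known at trace time: Dynamo
    # specializes on it and installs a guard, so each value gets its own graph.
    if use_sin:
        return x.sin()
    return x.cos()

x = torch.randn(4)
f(x, True)    # traces and compiles the sin unrolling
f(x, False)   # guard on use_sin fails -> recompiles with the cos unrolling
```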
Pros:
Cons:
An FX graph just calls operators. An operator internally can have whatever control flow it wants. So you can always black-box a problematic region of your model into an operator and preserve compilation for everything else.
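One way to do this in recent PyTorch is with torch.library.custom_op; here is a rough sketch (the op name and its loop are made up for illustration):

```python
import torch

# The data-dependent loop lives inside the op, where torch.compile never looks.
@torch.library.custom_op("mylib::iterative_norm", mutates_args=())
def iterative_norm(x: torch.Tensor) -> torch.Tensor:
    out = x.clone()
    while out.norm() > 1.0:   # data-dependent control flow, hidden inside the op
        out = out / 2
    return out

@iterative_norm.register_fake
def _(x):
    # shape/dtype propagation for tracing; no real computation
    return torch.empty_like(x)

@torch.compile(fullgraph=True)
def f(x):
    return iterative_norm(x) + 1

f(torch.randn(8))
```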
Pros:
Cons:
Do you really, really need a conditional? If you're doing an if-branch, can you instead rewrite it so that you run both branches and torch.where between the results? If you're doing a while-loop, can you unroll it to the max number of iterations and rely on dynamic shapes to make the extra iterations no-op once you're done? Basically, this option is to rewrite your model so it doesn't have Python-level control flow anymore (the conditional can be done either host-side or GPU-side).
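For example, a sketch of the if-branch rewrite: both branches run unconditionally and the result is selected with torch.where, so the whole thing stays in one graph.

```python
import torch

@torch.compile(fullgraph=True)
def f(x):
    pos = x.sin()                         # "true" branch, always computed
    neg = x.cos()                         # "false" branch, always computed
    return torch.where(x.sum() > 0, pos, neg)  # selection happens on-device

f(torch.randn(4))
```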
Pros:
Cons:
torch has special structured control flow operators that avoid unrolling large loops or needing to execute both branches of a control flow statement. If you’re familiar with JAX, these are very similar to the JAX equivalents. They have specific constraints that allow them to be directly compilable by torch.compile. For example, torch.cond accepts two functions (a true_fn and a false_fn) for the two branches and requires that outputs of each function must have the same properties (e.g. shape, dtype).
So far, we have the following “higher-order” operators (HOPs):
These are relatively new, have been used in torch.export for inference, but have not been battle tested for training or performance.
The semantics of these control flow operators are as follows:
def cond(pred, true_branch, false_branch, operands):
    if pred:
        return true_branch(*operands)
    else:
        return false_branch(*operands)

def while_loop(cond_fn, body_fn, carried_inputs):
    val = carried_inputs
    while cond_fn(*val):
        val = body_fn(*val)
    return val

def scan(combine_fn, init, xs, length=None):
    carry = init
    ys = []
    for x in xs:
        carry, y = combine_fn(carry, x)
        ys.append(y)
    return carry, np.stack(ys)
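Usage looks roughly like this (a small sketch using torch.cond, which works with a data-dependent predicate where a Python `if` would not):

```python
import torch

def true_fn(x):
    return x.sin()

def false_fn(x):
    return x.cos()

@torch.compile(fullgraph=True)
def f(x):
    # predicate is a tensor computed from data; both branches must return
    # outputs with matching shape/dtype
    return torch.cond(x.sum() > 0, true_fn, false_fn, (x,))

f(torch.randn(4))
```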
Pros:
Cons:
If FX graphs give you basic blocks, you can use them as building blocks for a language that does support conditionals, stringing them together with basic blocks. In fact, Helion, a kernel DSL, does exactly this, as it is common to need to directly write data-dependent conditionals and loops when writing kernels (it otherwise uses all PyTorch API functions, similar to conventional FX graphs). To do this, you would need to write your own Python frontend that parses Python directly to generate the CFG. TorchScript also does this, but the TorchScript frontend is unmaintained and we don't recommend using it (and it also doesn't generate FX graphs by default).
Pros:
Cons:
by Edward Z. Yang at September 05, 2025 02:01 PM
When training large scale LLMs, there is a large assortment of parallelization strategies which you can employ to scale your training runs to work on more GPUs. There are already a number of good resources for understanding how to parallelize your models: I particularly recommend How To Scale Your Model and The Ultra-Scale Playbook. The purpose of this blog post is to discuss parallelization strategies in a more schematic way by focusing only on how they affect your device mesh. The device mesh is an abstraction used by both PyTorch and JAX that takes your GPUs (however many of them you've got in your cluster!) and organizes them into an N-D tensor that expresses how the devices communicate with each other. When we parallelize computation, we shard a tensor along one dimension of the mesh, and then do collectives along that dimension when there are nontrivial dependencies between shards. Being able to explain why a device mesh is set up the way it is for a collection of parallelization strategies is a good check for seeing if you understand how the parallelization strategies work in the first place! (Credit: This post was influenced by Visualizing 6D Mesh Parallelism.)
tl;dr
Prologue: Why device mesh? Before we jump into the zoo, why do we have multi-dimensional meshes in the first place? One intuition is that the dimensions of the device mesh are a reflection of the physical constraints of networking between GPUs (there's a reason why all of the scaling books talk extensively about how the networking for GPUs works; you can't reason about what parallelization strategy you should use without knowing about this!) Let's imagine you have 1024 NVIDIA GPUs. You don't want to treat this 1024 GPUs as an undifferentiated blob of GPUs. Physically, these GPUs are grouped into nodes of eight which have much faster NVLink connections compared to cross-node communication which is done on a slower Infiniband connection. Intuitively, you will want to do something different depending on if you're doing intra-node communication or inter-node communication.
The device mesh imposes structure on this collection of GPUs. A mesh is typically specified as a tensor size (e.g., (128, 8)) as well as string axis names ala named tensor (e.g., ["dp", "tp"]), and is simply an N-D tensor over a range of GPU indices (typically [0, 1, 2, 3, ...] for GPUs, and a mostly ascending but occasionally permuted sequence for TPUs). We typically think of 2D and 3D tensors as grids and cubes, but I find it is more helpful (especially in higher dimensions) to think of the device mesh as imposing some self-similar (fractal) structure on the GPUs. In the simplest 2D mesh that accounts for intra versus inter node communication, GPUs are first organized into nodes on the inner-most dimension, and then the nodes are collected together in the outer-most dimension to form the cluster. (The self-similar nature of the nodes is important because it tells us how communication occurs across the cluster: to communicate over the outer-most mesh dimension, all the GPU 0s on each node talk to each other, all the GPU 1s, etc.) This is only the very simplest mesh we can create, however; with more complicated parallelization strategies we may impose extra levels of structure, e.g., we may organize nodes into pods of two and four, or we might further divide the eight GPUs of a single node. In other words, the mesh tells us about which GPUs communicate to which other GPUs. This is important to know, because when I want to parallelize our model, I am making choices about how to shard tensors across my GPUs. The mesh tells me which GPUs have the other shards of my tensor; in other words, they are who I have to communicate with when I am doing a computation that requires information about the full tensor and cannot be done with the local shards only.
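In PyTorch, constructing such a mesh looks roughly like this (a sketch; it assumes a 1024-GPU job with torch.distributed already initialized):

```python
import torch
from torch.distributed.device_mesh import init_device_mesh

# 128 nodes of 8 GPUs each: the inner "tp" dim groups the 8 GPUs within a node
# (fast NVLink), the outer "dp" dim groups corresponding GPUs across nodes
# (slower inter-node network).
mesh = init_device_mesh("cuda", (128, 8), mesh_dim_names=("dp", "tp"))

dp_mesh = mesh["dp"]  # 1-D sub-mesh this rank uses for cross-node collectives
tp_mesh = mesh["tp"]  # 1-D sub-mesh for intra-node collectives
```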
In the zoo, when we talk about a parallelism strategy, we will talk about how it typically relates to other parallelization strategies in the model, and the device mesh will tell us if it is orthogonal to other parallelisms (a new dimension), multiplexed with another strategy (a reused dimension), or perhaps a completely different hierarchy of communication (multiple meshes in the same model that don't factor into each other).
Without further ado, here is the zoo!
Data parallelism (DP). Data parallelism predates the concept of device meshes, since you don't actually need any nontrivial mesh structure to do data parallelism: if you are only doing data parallel, you just shard your input on the batch axis for however many devices you have. This sharding propagates through forwards and backwards until you allreduce to compute the final global gradient for a parameter. If you did make a 1D device mesh (this is useful to think about, because most higher dimensional parallelisms will include some form of data parallelism), you'd probably name your mesh ["dp"], ["ddp"] or perhaps ["batch"].
Let's talk briefly about how people tend to name device mesh axes. In the PyTorch world, it's most common to name the axis after the parallelism that it is responsible for, so either "dp" or "ddp" (you really shouldn't call it ddp, but the DataParallel taboo in PyTorch is very real!) The batch name is common in JAX, and is very natural there because when you annotate the sharding of your input, you need to say for each tensor dimension what mesh dim it is sharded over. So when you shard the batch dimension over the batch mesh dim, it looks just like you're labeling the batch dimension of your tensor as batch, e.g., P("batch", None). (This situation doesn't happen in PyTorch because shardings of a tensor are specified per device mesh dim, but that's a story for another day!)
Fully-sharded data parallel (FSDP). This is best understood as an augmentation of DP where weights are also sharded over all GPUs and you just all-gather weights before performing operations (and reduce-scatter gradients in backwards). Because this all-gather is also among all devices, you don't need another axis in your mesh, and your mesh might also be called ["dp"] in this case, even though you're actually doing FSDP. Occasionally, you'll see people name their mesh ["fsdp"] in this case.
Hybrid sharded data parallel (HSDP). HSDP is an extension of FSDP where you shard weights (FSDP) up to the point where you can't actually do a giant all-gather/reduce-scatter over every GPU, and then replicate these shards to cover the rest of your cluster (DP). It's also amenable to fault tolerance techniques that make the modeling assumption that it's OK to lose samples of your batch if a replica fails (you won't model this with device mesh though!). This is probably the first time you will encounter a 2D device mesh (indeed, the DeviceMesh tutorial in PyTorch specifically uses hybrid sharding as its motivating example), since HSDP doesn't require any extra model changes on top of FSDP. There are a few common ways to name the mesh axes for HSDP. One way to think about it is that it is FSDP on the inner dimension and DP on the outer dimension, in which case you would say ["dp", "fsdp"]. Another way is to think about what happens to parameters at the various layers of the mesh: the inner dimension shards, while the outer dimension replicates, so you would say ["replicate", "shard"] or perhaps ["dp_replicate", "dp_shard"] to make it clear that you are still doing data parallelism across both of these device mesh dims (in particular, when you split your batches, you split on both the dp_replicate and dp_shard dims--although, to get the final gradients, you can do the reduction hierarchically by first doing a reduce-scatter on "dp_shard" and then doing an allreduce on "dp_replicate").
Tensor parallelism (TP). Depending on who you ask, tensor parallelism is either about letting you reduce your effective batch size for training or moving you towards reducing the memory usage of activations in your model. In the "reduce effective batch size" framing, the idea behind TP is that you can only scale up DP until your cluster is as large as your batch size. From a modeling perspective, it can be undesirable to have a batch size that is too large, so you can't just keep increasing your batch size to get more parallelism. Instead, TP allows us to get some extra scaling by sharding over the feature dimension of our matrix multiplies [1] (you can shard over either the columns or the rows of your weight matrix, so we will frequently specify if a TP Linear is column-wise or row-wise; in attention, column-wise linear effectively parallelizes the attention computation over attention heads). The communication needed to do TP is fairly exposed (unless you're doing async tensor parallel), so you typically want to keep the communications for it within a single node. This leads to this classic 2D device mesh for DP+TP: ["dp", "tp"] (or, if you're a JAXer, you might write ["batch", "model"], where model is used to indicate the inner feature dimension of the model weights being parallelized over.) When someone says 2D parallelism, they're usually referring to this combo of parallelisms (although I do not recommend using this term--as you can see, it is obviously ambiguous!) Note that tp is the inner mesh dimension, since it benefits the most from the high bandwidth network between GPUs on a single node.
You don't have to stop with DP+TP, however. If you're using FSDP with tensor parallelism (remember, "dp" can mean FSDP!), intra-node TP doesn't improve the amount of inter-node FSDP communication you have to do: however much TP you do, within one TP node you only have one slice of the model and have to talk to everyone else to get their slices. You could solve this by expanding TP to also cross nodes, but in practice mixed intra/inter-node collectives are a lot slower than pure inter-node collectives. This limits the scaling you can get from TP, and so if you're still hitting limits on FSDP, it can still be useful to apply HSDP to avoid running collectives that are too large. In that case, you'd end up with a mesh like ["dp_replicate", "dp_shard", "tp"].
Sequence parallelism (SP). For this section, we specifically take the definition of sequence parallelism from the Ultrascale Playbook (as distinguished from context parallelism). Although we said that TP is the first step towards reducing the memory usage of activations [2], if you literally implement DP+TP based on my descriptions above, you will still end up with more memory spent on activations than you want, because there are still parts of the model around the FFN, like the LayerNorm, that need the full hidden dimension to compute mean and variance [3]. To reduce the memory usage in these segments, you need to shard on something else. So typically what you will see is that the model will alternate between TP (hidden dimension is sharded) and SP (sequence dimension is sharded). Consequently, if you look at the device mesh for a model using DP+TP+SP, it will typically still look like ["dp", "tp"], and instead the tp dimension is multiplexed to be used both for TP and SP. Because TP and SP never occur at the same time, you don't need a separate dimension for them.
Ulysses sequence parallelism. Ulysses sequence parallelism from DeepSpeed Ulysses is another sequence parallelism strategy that is implemented by verl (because verl is forked so often, it shows up quite prominently if you are looking for examples of init_device_mesh on GitHub code search). It aims to alleviate memory pressure from extremely long sequences, so sequences are sharded on input, and only when attention needs to be computed is an alltoall issued to re-shard on the attention heads rather than the sequence (doing another alltoall to restore the sequence sharding after the attention is done). Importantly, this means it competes with TP for sharding on the attention heads, which is why you also see people use it to replace TP in MoE models, since it has much less communication than TP (at the cost of having to replicate the attention weights). In verl, you will just see a device mesh ["dp", "sp"] when you are using their FSDP backend (which is what supports Ulysses).
Context parallelism (CP). Context parallelism is another form of "sequence" parallelism. Like Ulysses sequence parallelism, sequences are sharded on input; the difference, however, is instead of using an alltoall to re-shard on attention heads, you just do a (distributed) attention on the entire context. You can do this the easy way by just using allgather to get the full context (as was done in llama4) or you can use a fancy kernel like ring attention, which carefully overlaps communication and computation when performing attention. A popular implementation of context parallelism lives in Megatron, which doesn't directly use PyTorch's native DeviceMesh abstraction but has an analogous HyperCommGrid. The mesh we see here will be something like ["dp", "cp"] or more commonly ["dp", "cp", "tp"]. Notice that we can have a dedicated mesh dim for CP: CP operates very similarly to SP outside of the attention calls (as it is just plain data parallelism when there is no cross-token dependency), but because it never shards on attention heads, it doesn't compete with TP and can be used completely orthogonally to TP (TP shards hidden, CP shards sequence).
CP has a pretty interesting interaction with FSDP. Both DP and CP shard the input data (on batch and sequence respectively). It's pretty common when you do FSDP to just shard over both "dp" ("dp_shard" in HSDP) and "cp". In torchtitan, we create a flattened mesh dim "dp_shard_cp" specifically for FSDP sharding (a flattened mesh dim is what happens if you take your mesh and "forget" about some of the structure; e.g., if you were to do an all-gather, you just all-gather over all the flattened axes). In the HSDP world, "dp_cp" is still a useful concept because this is the combination of axes you want to all-reduce over to, e.g., compute the global average loss.
Pipeline parallelism (PP). Pipeline parallelism is kind of an ugly duckling and people tend to hate on it because you have to rewrite your models to introduce pipeline stages, and you can't really use things like DTensor with it (unless you do really strange things like how the GSPMD paper "supports" pipeline parallelism--the general consensus is automatic parallelism does not like PP). PP still goes in the device mesh, because it affects how you are organizing your GPUs, but, for example, torchtitan solely uses it to setup PGs for doing the point-to-point communications. I've seen both ["dp", "pp", ...] or ["pp", "dp", ...] for meshes with PP, but the order probably doesn't make too much of a difference as you are likely solidly inter-node at this point. Pipeline parallelism bandwidth use is very low, and latency can be covered up as you can immediately start processing the next batch after triggering an asynchronous send of the previous batch.
Expert parallelism (EP). EP is its own kettle of fish. Expert parallelism only applies over the expert computation of the model, but within this region, we are not sharding parameters as FSDP conventionally sees it: we will commonly have the entire expert's weights on our node. torchtitan's WIP expert parallelism implementation, when it has ALL parallelisms on, would look like ["pp", "dp_replicate", "dp_shard_mod_ep", "dp_shard_in_ep", "cp", "tp"], where dp_shard has been split into two mesh dimensions (DP shard modulo EP, and DP shard in EP). dp_shard_mod_ep is conventionally one, but when it is not it represents further FSDP-style sharding of expert weights inside of the expert region (there's some complication here if you have shared experts alongside your EP-sharded experts). But then dp_shard_in_ep, cp and optionally tp are combined together to give you the expert parallel dimension. It's actually more intuitive to imagine that you have two distinct meshes: ["pp", "dp_replicate", "dp_shard", "cp", "tp"] and ["pp", "dp_shard_mod_ep", "ep", "tp"]. The keen-eyed may also notice that there is no intrinsic reason the tp mesh size inside and outside of the expert parallel region has to be the same, but making them differ is not easily done if you have to have a single global device mesh for everything. In fact, there is a WIP PR to have two meshes, one for inside the expert region and one for outside: https://github.com/pytorch/torchtitan/pull/1660
Conclusion. The general concept behind mesh parallelism is that you can compose parallelization strategies without too much fuss. Indeed, TP improves scaling precisely because it lets you cover your device space without having to expand DP beyond the batch size you want to use. However, as you can see from these concrete examples, it's not always quite as simple as just stacking all of the parallelisms one on top of the other. In the end, all the device mesh is doing is creating process groups for the groups of devices defined by the mesh, so if you want some weird setup where you're swapping between two device meshes, PyTorch's general philosophy has been to say: have fun!
Thanks to Horace He, Tianyu Liu and Natalia Gimelshein for helping fact check this post. Any remaining errors are mine!
[1] One more subtlety I want to point out: while we tend to think of TP as sharding the feature dimension of parameters, when we "propagate" this sharding through the network, other intermediate tensors end up getting sharded on the TP dimension as well. In particular, in a transformer block, you will typically have a column-wise linear followed by a row-wise linear, and the intermediate activation will be temporarily sharded on the TP dimension before the row-wise linear runs.

[2] I am very carefully using "activation memory" here and not total memory, because total memory usage (what you actually care about) is also a function of peak memory usage, which is subject to transient peaks such as when FSDP does an all-gather to collect parameters. In fact, even without SP, TP will improve your peak memory usage, because unlike FSDP, it's not necessary to all-gather the full weight matrix to actually perform the matrix multiply. TP's peak memory usage occurs when it all-gathers activations.

[3] You will get a little improvement between the column-wise and row-wise linear, since the activations there are sharded. You can turn this into a big improvement by using selective activation checkpointing and forcing recomputation of activations that aren't sharded! (Plain activation checkpointing tends not to work so well because of the all-gather of the activations.)
by Edward Z. Yang at August 31, 2025 03:20 AM
CuTe is a C++ library that aims to make dealing with complicated indexing easier. A key part of how it does this is by defining a Layout type, which specifies how to map from logical coordinates to physical locations (CuTe likes to say layouts are "functions from integers to integers.") In fact, CuTe layouts are a generalization of PyTorch strides, which say you always do this mapping by multiplying each coordinate with its respective stride and summing them together, e.g., i0 * s0 + i1 * s1 + .... Although NVIDIA's docs don't spell it out, CuTe's generalization here is actually very natural, and in this blog post I'd like to explain how you could have invented it (on a good day).
First, a brief recap about strides. PyTorch views allow us to reinterpret the physical layout of a tensor in different ways, changing how we map logical coordinates into physical locations. For example, consider this 2-D tensor:
>>> torch.arange(4).view(2, 2)
tensor([[0, 1],
[2, 3]])
>>> torch.arange(4).view(2, 2).stride()
(2, 1)
The physical memory reads 0, 1, 2, 3, and if I want to know what the value at coordinate (0, 1) is (row 0, col 1), I compute 0 * 2 + 1 * 1, which tells me I should read out the value at index 1 in physical memory. If I change the strides, I can change the order I read out the physical locations. For example, if I transpose I have:
>>> torch.arange(4).view(2, 2).T
tensor([[0, 2],
[1, 3]])
>>> torch.arange(4).view(2, 2).T.stride()
(1, 2)
The physical memory hasn't changed, but now when we read out coordinate (0, 1), we compute 0 * 1 + 1 * 2, which tells me I should read the value at index 2 (which is indeed what I see at this coordinate!)
PyTorch also allows us to "flatten" dimensions of a tensor, treating them as a 1D tensor. Intuitively, a 2-D tensor flattened into a 1-D one involves just concatenating all the rows together into one line:
>>> torch.arange(4).view(2, 2).view(-1)
tensor([0, 1, 2, 3])
We should be able to do this for the transpose too, getting tensor([0, 2, 1, 3]), but instead, this is what you get:
>>> torch.arange(4).view(2, 2).T.view(-1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
The dreaded "use reshape instead" error! The error is unavoidable under PyTorch striding: there is no stride we can select that will cause us to read the elements in this order (0, 2, 1, 3); after all, i0 * s0 is a pretty simple equation, we can't simultaneously have 1 * s0 == 2 and 2 * s0 == 1.
Upon learning this, an understandable reaction is to just shrug, assume that this is impossible to fix, and move on with your life. But today, you are especially annoyed by this problem, because you were only trying to flatten N batch dimensions into a single batch dimension so that you could pass it through a function that only works with one batch dimension, with the plan of unflattening it when you're done. It doesn't matter that this particular layout is inexpressible with strides; you aren't going to rely on the layout in any nontrivial way, you just care that you can flatten and then unflatten back to the original layout.
Imagine we're dealing with a tensor of size (2, 2, 2) where the strides for dim 0 and dim 1 were transposed as (2, 4, 1). It should be OK to flatten this into a tensor (4, 2) and then unflatten it back to (2, 2, 2). Intuitively, I'd like to "remember" what the original sizes and strides are, so that I can go back to them. Here's an idea: let's just store the original size/stride as a nested entry in our size tuple. So instead of the size (4, 2), we have ((2, 2), 2); and now analogously the stride can simply be ((2, 4), 1). When I write (2, 2) as the "size" of a dimension, I really just mean the product 4, but there is some internal structure that affects how I should index its inside, namely, the strides (2, 4). If I ask for the row at index 2, I first have to translate this 1D coordinate into a 2D coordinate (1, 0), and then apply the strides to it like before.
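Here is a quick sketch of that idea in code (my own toy illustration, not CuTe's API): a size is either an int or a tuple of nested sizes, and likewise for strides.

```python
import math

def numel(size):
    # Total extent of a possibly-nested size.
    return size if isinstance(size, int) else math.prod(numel(s) for s in size)

def offset(i, size, stride):
    # Map a linear index i within one (possibly nested) dimension to a physical
    # offset. An int size is the ordinary i * stride; a nested size first
    # unpacks i into sub-coordinates and recurses.
    if isinstance(size, int):
        return i * stride
    off = 0
    for s, d in zip(reversed(size), reversed(stride)):
        # Last sub-dim varies fastest (lexicographic order, as in the example
        # above; CuTe itself linearizes co-lexicographically).
        i, sub = divmod(i, numel(s))
        off += offset(sub, s, d)
    return off

# Size ((2, 2), 2) with strides ((2, 4), 1): asking for "row 2" of the flattened
# dimension unpacks to coordinate (1, 0) and gives physical offset 1*2 + 0*4 = 2.
print(offset(2, (2, 2), (2, 4)))  # -> 2
```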
Well, it turns out, this is exactly how CuTe layouts work! In CuTe, sizes/strides are hierarchical: a size is actually a tree of ints, where the hierarchy denotes internal structure of a dimension that you can address linearly (in fact, everything by default can be addressed in a 1-D linear way, even if it's an N-D object.) The documentation of Layout does say this... but I actually suffered a lot extracting out the high level intuition of this blog post, because CuTe uses co-lexicographic ordering when linearizing (it iterates over coordinates (0,0), (1,0), (2,0), etc. rather than in the more normal lexicographic order (0,0), (0,1), (0,2)). This leads to some truly deranged example code where they print a 2D matrix in conventional lexicographic ordering, and then turn around and say, "But wait, if I have the layout take care of translating the 1D coordinate into an ND coordinate, it is colexicographic!!":
> print2D(s2xh4)
  0    2    1    3
  4    6    5    7      # sure, why not?
> print1D(s2xh4)
  0    4    2    6    1    5    3    7      # wtf???
In any case, if you want to engage with the documentation, s2xh4 is the important example to pay attention to for understanding the nested semantics. However, note the example is smeared across like five sections and also you need to know about the co-lexicographic thing to understand why the examples print the way they do.
by Edward Z. Yang at August 22, 2025 06:48 AM
The purpose of this post is to sum up, in one place, the state of torch.compile for training as of August 2025. Nothing in here is something you couldn't already learn from elsewhere on the Internet, but we rarely put everything together in one place. The target audience for this document is teams who are evaluating the use of torch.compile for large scale training runs.
First, the basics. torch.compile (also known as PT2) is a compiler for PyTorch eager programs for both inference and training workloads. Speedups from 1.5-2x compared to eager code are typical, and torch.compile also makes it possible to do global optimizations for memory (e.g., automatic activation checkpointing) and distributed communications (e.g., async tensor parallelism).
The headline functionality of torch.compile is a decorator you can attach to a function to compile it:
@torch.compile()
def f(x, y):
...
Here are some non-functional properties of compile which are important to know:
For large scale training runs, torch.compile faces stiff competition from (1) PyTorch native distributed frameworks which embrace eager mode and implement all optimizations by hand (e.g., megatron), (2) custom "compiler" stacks which reuse our tracing mechanisms (e.g., symbolic_trace and make_fx) but implement their desired passes by hand, (3) JAX, which has always been XLA first and is years ahead in compile-driven parallelism techniques.
Here is where we currently are for advanced parallelism (with an emphasis on comparing with JAX):
Functional collectives. If you don't like DTensor, we also support "functional collectives", which are non-mutating versions of collective operations that can be used to manually implement SPMD operations in a compiler-friendly way without needing DTensor. (In fact, if you use traditional collective APIs and compile them, we will silently translate them into functional collectives for compiler passes.) When compiled, functional collectives don't necessarily force allocation of the output buffer as they can be re-inplaced. Importantly, functional collectives currently do NOT support autograd, see https://discuss.pytorch.org/t/supporting-autograd-for-collectives/219430
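A small sketch of what using them looks like (it assumes torch.distributed is already initialized; the helper name is made up):

```python
import torch
import torch.distributed._functional_collectives as funcol

# The functional all_reduce returns a new tensor instead of mutating its input,
# which is what makes it friendly to tracing and to compiler passes.
def allreduce_mean(t: torch.Tensor, world_size: int) -> torch.Tensor:
    summed = funcol.all_reduce(t, reduceOp="sum", group=list(range(world_size)))
    return summed / world_size
```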
Not an SPMD compiler by default. torch.compile does not assume the program being compiled is SPMD by default, which means it will not do things like drop unused collectives (you can change this behavior with a config flag). Additionally, the default mode of use for torch.compile is to compile in parallel on all nodes, which means care has to be taken to ensure that every instance of the compiler compiles identically (only one rank recompiling, or compilers making different decisions, can lead to NCCL timeouts). We ultimately think that we should compile a program once and send it to all nodes, but as this is not currently implemented, the general approach people have taken to solve this problem is to either (1) eliminate all sources of divergent behavior from ranks, e.g., don't allow the compiler to look at the actual size for dynamic inputs when making compiler decisions, or (2) introduce extra collectives to the compiler to communicate decisions that must be made consistently across all ranks.
Our vision for the future of advanced parallelism, spearheaded by the in-progress SimpleFSDP and AutoParallel, is that users should write single-node programs that express mathematically what they want to do. These are then transformed into efficient distributed programs in two steps: (1) first, collectives are inserted into the graph in a naive way (i.e., simply to express what the sharding of all intermediates should be), and (2) the collectives are optimized to handle scheduling concerns such as pre-fetching and bucketing. AutoParallel sets a GSPMD-style goal of automatically determining a good enough sharding for a program--it should be able to rediscover data parallel, tensor parallel, even expert parallel(!)--but SimpleFSDP sets a smaller goal of just inserting collectives in the pattern that FSDP would mandate, and then writing FSDP-specific optimization passes for recovering FSDP2's performance. It is very common to write domain specific optimizations; for example, async tensor parallelism is also implemented as a pass that detects TP patterns and rewrites them to async TP operations. Unlike JAX, which started with a very generic solver and has needed to add more manual escape hatches over time, PyTorch has started with writing all of the distributed patterns exactly by hand, and we are only recently adding more automatic mechanisms as an alternative to doing everything by hand.
torch.compile performs many optimizations, but here are some particularly important ones to know about:
torch.compile is a just-in-time compiler and as such, in its default configuration, compilation will occur on your GPU cluster (preventing you from using the GPUs to do other useful work!) In general, most pathological compile times arise from repeated recompilation (often due to dynamic shapes, but sometimes not). In Transformer models, compile time can also be improved by only compiling the Transformer block (which can then be compiled only once, instead of having to be compiled N times for each Transformer block in the model).
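A sketch of the per-block compilation trick, using a toy model in place of a real transformer:

```python
import torch
from torch import nn

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # stand-ins for N identical transformer blocks
        self.layers = nn.ModuleList(nn.Linear(64, 64) for _ in range(4))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

model = ToyModel()
# Compile only the repeated block: because every block runs the same code,
# one compilation is reused across all of them instead of compiling the
# whole model end to end.
for layer in model.layers:
    layer.compile()  # in-place nn.Module.compile(), same effect as torch.compile on the module

model(torch.randn(2, 64))
```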
We don't think caching is an ideal long-term solution for large scale training runs, and we have been working on precompile to solve the gap here. Precompile simply means having compilation be an ahead-of-time process which produces a binary which you can directly run from your training script to get the compiled model. The compilation products are built on top of our ABI stable interface (developed for AOTInductor) which allows the same binaries to target multiple PyTorch versions, even though PyTorch the library does not offer ABI compatibility from version to version.
The most typical pattern we see for people who want to make use of torch.compile for large-scale training is to fork torchtitan and use this codebase as the basis for your training stack. torchtitan showcases PyTorch native functionality, including torch.compile--in effect, it shows you how to use features in PyTorch together in a way that lets you do large-scale training. From there, swap out the components you are opinionated about and keep the things you don't care about.
by Edward Z. Yang at August 14, 2025 02:33 AM
While investigating the performance of the new Python 3.14 tail-calling interpreter, I learned (via this very informative comment from Sam Gross) a new (to me) piece of performance trivia: modern CPUs mostly no longer struggle to predict the bytecode-dispatch indirect jump inside a "conventional" bytecode interpreter loop. In steady state, assuming the bytecode itself is reasonably stable, modern CPUs achieve very high accuracy predicting the dispatch, even for "vanilla" while/switch-style interpreter loops!
A lot of strong engineers that I know haven't really taken a serious look at AI coding; they've used LLMs to ask questions or write simple scripts and appreciate that it is a useful tool, but haven't actually tried building a nontrivial application entirely from scratch in vibe coding style (here, I use the term in its original meaning: when you do AI coding without carefully reviewing the output). This is understandable: if you're not working on a green field project, there aren't that many opportunities to write code in this style--standard practice for established projects is that someone else needs to review all of the code you write: this is a bad match for vibe coding! So in this post, I want to give a concrete case study of a nontrivial system that was entirely vibe coded (ScubaDuck), to argue the following claims:
Update: You can see all of my prompts and the resulting agent trajectories at scubaduck-prompts.
ScubaDuck is a discount implementation of Meta's internal Scuba realtime database system. You can read more about what exactly this is on GitHub, but it's not so important for the purposes of this post: the key details you need to know about ScubaDuck is that it consists of a Python server that exposes an API to perform queries against a DuckDB database, and an HTML and JavaScript frontend application which implements the forms for building these queries and rendering of the output data. Both the forms and output data rendering have nontrivial JavaScript enhancements: some form inputs are chip inputs and support autocomplete, and the time series view is an SVG chart. All of these components were coded from scratch, so the project has no third-party JavaScript dependencies.
So on the one hand, this project is pretty simple. There are no stringent performance or uptime requirements, it's a pretty standard server-client program that the LLM has seen millions of times before (this is good!) On the other hand, the exact behavior of the frontend UI is quite intricate and would be very difficult to one-shot in a single prompt. Indeed, as I was coding and testing the application, I frequently ran into situations that I didn't anticipate in my original specification, and that I had to ask Codex to refine. Another way to put it is that ScubaDuck is a relatively simple functional specification (although this too was not one shot), but I did a lot of polishing of small behaviors so that the interface behaved in the way that I expected Scuba to behave. Here, it was helpful that I had a very clear idea of what I wanted (since I've used Scuba quite a lot at work).
Going into ScubaDuck, I had a pretty good sense that this project should be a good fit for LLMs. HTML, JavaScript and Python are all extremely high resource languages, and I'd heard lots of people raving about how good LLMs were at transforming wireframes and mockups into fully functional websites. It is also fully self contained and straightforward-ish to test (only "ish" because you do have to use something like Playwright to actually test the frontend UI, which honestly is a slog. But fortunately, the LLM can write the tests for you!) One design decision I made, which I didn't originally anticipate but worked out in the end, was the decision to not use any third-party JavaScript libraries. This was by accident: Python has no native way of bundling third-party JavaScript, but I wanted the tool to work offline. I wasn't sure if you could vibe code an SVG charting library from scratch, but apparently you can and it's quite easy!
ScubaDuck was implemented with OpenAI Codex in the cloud (not the CLI tool). Codex's cloud offering requires you to initialize a hermetic environment which the coding agent can execute commands in. It's pretty well known now that AI coding agents work much better if they are able to run the code they write and see if it worked or not, so this is quite an important part of the process. Unfortunately, setting this up took some time-consuming trial and error. I had a fairly detailed initial prompt, and what I would do was submit it to Codex, watch it fail, read over the trajectory (the agent logs) to see what happened (Codex wanted to use npm! Codex couldn't download something from the internet! Codex tried to use a package that wasn't available!) and then fix whatever environment misconfiguration had caused it to fail, or edit AGENTS.md to instruct it not to do some behavior. According to my history, the first day of the project was spent unsuccessfully trying to get the project set up, and my first successful Codex PR only happened on May 19.
At the end of setup, I had the following:
After I got my initial prompt to generate a first draft of the application, I was able to begin vibe coding in earnest.
The basic vibe coding loop works like this:
For example, after the very first PR, some very mild poking around immediately revealed the bugs fixed in #2:
There's a race condition in the current test logic for matching against table contents in run_query. Specifically, if there were previously valid results in lastResults, and for some reason Dive doesn't do anything, then we will still see the old results. The testing framework should explicitly clear lastResults before attempting an interaction.
...and #3:
Filter functionality does not work. We will first add a failing test, and then fix it. The failing test should click "Add Filter", then select "user" as the field, and then add an "alice" chip (by typing alice in the text box and pressing ENTER). Then when we dive, we should see two alice rows. Right now, NO request is issued at all when we click Dive. Diagnose and then fix the problem.
Prompt the agent to write tests. It's very helpful to prompt the agent to generate tests for whatever bugs its fixing. For frontend code, I decided to use playwright to write these tests. An example in #11:
def test_header_and_tabs(page: Any, server_url: str) -> None:
    page.goto(server_url)
    page.wait_for_selector("#order_by option", state="attached")
    header = page.text_content("#header")
    assert "sample.csv" in header
    assert "events" in header
    assert page.is_visible("#settings")
    assert page.is_hidden("#columns")
    page.click("text=Columns")
    assert page.is_visible("#columns")
    cols = page.locator("#column_list li").all_inner_texts()
    assert "timestamp" in cols
    assert "event" in cols
    page.click("text=View Settings")
    assert page.is_visible("#settings")
    btn_color = page.evaluate("getComputedStyle(document.querySelector('#dive')).backgroundColor")
    assert "rgb(0, 128, 0)" == btn_color
    sidebar_overflow = page.evaluate("getComputedStyle(document.querySelector('#sidebar')).overflowY")
    view_overflow = page.evaluate("getComputedStyle(document.querySelector('#view')).overflowY")
    assert sidebar_overflow == 'auto'
    assert view_overflow == 'auto'
This is a very manual and somewhat brittle test. I doubt I would have had the patience to write and debug it by hand. But the LLM will happily do it. There is also a fear that the test might not actually be testing anything for real. But because these tests are all generated from bug reports, you can prompt the model to write the test first, check that it's failing, and then fix the problem. I didn't instruct this in the above prompt, but I did instruct it for #155, for example. The trajectory shows that the LLM demonstrates it can repro the bug before fixing the issue:
I’ll add a case-insensitive test for "samples" to check server behavior with lowercase input.
(scubaduck) root@5b69743466e4:/workspace/scubaduck# python - <<'EOF'
> import json
> from scubaduck import server
>
> app = server.app
> client = app.test_client()
> payload = {
>     "table": "events",
>     "start": "2024-01-01 00:00:00",
>     "end": "2024-01-03 00:00:00",
>     "graph_type": "table",
>     "order_by": "samples",  # lowercase
>     "order_dir": "DESC",
>     "limit": 10,
>     "columns": [],
>     "group_by": ["user"],
> }
> rv = client.post("/api/query", data=json.dumps(payload), content_type="application/json")
> print('status', rv.status_code)
> print('data', rv.get_json())
> EOF
status 400
data {'error': 'Unknown column: samples'}
(scubaduck) root@5b69743466e4:/workspace/scubaduck#

I’ve found the bug happens when "Samples" is lowercase. I’ll adjust the query to handle case insensitivity and add a test for it. Time to modify server.py to fix the issue!
Beware BC-breaking changes. The LLM generated tests worked well when I didn't need to change behavior in a BC-breaking way. Example prompt:
Change time column to default to none unconditionally. (I want to default it to a time column if we find a good candidate but this threshold needs tuning and I don't have the tuning right now.)
This hit a Codex timeout:
This attempt to update the default time column led to extensive frontend test failures that require deeper changes. The environment made it difficult to fully diagnose the Playwright test behaviors within the time available, so the work could not be completed.
In standard software engineering practice, when this happens, decouple the BC compatible and BC breaking changes!
Make it so that Time Column can be set to (none). When it is set this way, the Start/End fields are hidden and we don't apply a filter on time range. (#115)
and then later, instead of defaulting the time column to none, I added a heuristic to pick a column that looked like time, which happened to pick the same column the existing tests already expected.
Refactors have to be split up. Codex's timeout means that you can't ask it to do too much in one go. Here's a prompt that timed out:
scubaduck/index.html has gotten a bit long. Let's split out some of the JS code into dedicated JS files for their functionality. Also setup the necessary Flask scaffolding to serve these JS files. I think splitting out these specific components would be good:
- Dropdown implementation
- Sidebar resizing
- JS controlling the View Settings (e.g., updateDisplayTypeUI, as well as one off interactions on form elements, columns handling, filter handling, the actual Dive implementation (including query updating), reading in defaults from query string)
- Table rendering (e.g., formatNumber, sorting)
- Chip input implementation
- Chart rendering (showTimeSeries)
Make changes to AGENTS.md or README.md describing the structure so you can quickly find where the components you need are
I eventually did manage the refactor by prompting Codex to individually move out the pieces I wanted to extract one-by-one. This is a place where I think Claude Code probably would have performed better.
Parallelizing tasks. As you can see from the lengths of my prompts, it does take a while to write a good prompt; you're basically writing a bug report with enough detail that the LLM can repro it and then fix it. So sometimes I would be bottlenecked on prompt writing. However, sometimes the prompts were quite short. In those cases, Codex encourages you to submit more tasks that can run in parallel. I found this worked well, and I'd sometimes have as many as five instances going (once again, rate limited by discovering problems, making designs and typing prompts!) One irritation is when the tasks end up conflicting with each other. Sometimes the conflicts are easy to fix, but if it feels nontrivial, it's often better to just ask Codex to redo one of the PRs on latest main after the other has landed. To avoid merge conflicts, it helps to have only one "main feature" agent going at any time, and then ask the agent to do random bugfixes in parallel with it. Once you have no more tasks to get running, you can go do something else while you wait for the agents to finish (manager schedule!)
As a reminder, I've posted all of my prompts (including the ones that failed) at scubaduck-prompts, and I think it's helpful to skim through them to get a flavor of what I was asking the LLM. But to summarize, what did I spend most of my time prompting Codex to do? My general vibe (ahem) is that I spent most of my time on minor enhancements, where I instructed Codex to make some part of the program work slightly differently, in a way that had been left unspecified by the previous prompt. The metaphor I had in my head while I was working on the project was that of a sculptor chiseling away at marble: in the beginning, anything is possible, but as I kept prompting, I continuously narrowed down the space of possible programs until I had exactly the one I wanted. One big thing I want to note is that Codex rarely needed to make updates to my tests; for the most part, tests that were added never got taken away, because I never "changed my mind". I suspect that the vibe coding process would have been rockier if I had been changing behavior frequently.
One of the things that surprised me the most about the process was how easy it was to implement a line chart in SVG with Codex. My first prompt resulted in a chart that looked broken on the test data:
We're going to add a new View type, to go along with Samples and Table: Time Series. Time Series supports all the fields that Table supports, and a few more:
- X-axis: Main group by dimension, e.g., the x-axis on time series view. This is our custom dropdown selector, but only time columns are populated here. It should prefer a default setting from the following list, most preferred first: "time", "timestamp"
- Granularity: Choose the time interval between data points on the chart. For example, a granularity of 1 hour means there will be a data point every 60 minutes that is aggregated with the chosen Aggregate function over the data for the granularity period before point. This is a plain drop down. The valid values are: Auto, Fine, 1 second, 5 seconds, 10 seconds, 30 seconds, 1 minute, 4 minutes, 5 minutes, 10 minutes, 15 minutes, 30 minutes, 1 hour, 3 hours, 6 hours, 1 day, 1 week, 30 days. The semantics of the Auto setting is that it sets the interval to whatever would result in maximum 100 buckets (if there are not enough data points for that many buckets, it just picks the finest time interval that makes sense), and Fine which sets the interval to 500 buckets.
- Fill Missing Buckets: This is a dropdown. For now, it has the settings "Fill with 0 (Per Series)" (default), "Connect (Per Series)" and "Leave blank".
Additionally, the default setting of Limit is 7, as it controls how many elements from group by will be plotted (the actual number of lines plotted could be a multiple of this, as we will plot every selected Column).
Unlike Samples and Table, we will instead display a line chart in the right panel. To plot the line chart, we will implement it by hand with JS and SVG, similar to how highcharts implements it. We will not use any third party dependencies. Lines will be plotted as paths, no smoothing, no dots for individual data points. Each series (as generated by group by) should be plotted with a different color, assigned using a best practices color palette for graph design. There should be a rendering of x-axis and y-axis; the x-axis should have slanted labels to aid readability. When we mouse over the chart, a vertical line should snap to the center of the time bucket that we are closest to. We should also display a crosshair on all of the series showing us their values at that data point, and highlight the closest point we are on, and increase the thickness of the series that point is on. To the left of the graph (still in the right panel), there should be a legend. The legend looks like this:
[GROUP BY VALUE] [AGGREGATE]
[First Column name, with series color] [Number of samples for the first column]
[Second Column name, with series color] [Number of samples for the second column]
... for all columns
----
... for all group by values (up to the limit)

So for example, if I group by user, I might see:

Alice AVG
value 4 (samples)

The highlighted series (which has a thicker line) should also be highlighted in the legend.
This was kind of terrifying, because I initially thought I didn't have a good way to test the SVG outputs. But after doing some regular old-fashioned debugging and reading the code (yes, this part not vibe coded), I figured out the problem, and also realized that Playwright can test that an SVG path is not just entirely straight. After the initial bugs were fixed, I mostly had to add missing features like x-axis/y-axis and interactivity features (amusingly, Codex ignored most of the instructions in the latter half of the prompt, giving only the barest bones legend. I suspect this was because I had some files which were too long). My general take after this was that JS chart libraries are going to become obsolete: it's much easier to vibe code a bespoke implementation and then customize the heck out of it.
ScubaDuck was implemented in about 150 Codex prompts. As you can see from the sample prompts above, the prompts are recognizably programming; they just happen to be written in plain English. This is a big help, because I never had to keep track of the nest of callbacks and state machines needed to implement complex UI elements in JavaScript. I had to be fluent in what I wanted my program to do, and a good QA tester for the application to discover new problems that needed to be fixed, but I did not have to worry at all about the vagaries of SVG DOM elements or pixel position computation minutiae. It's hard to say how long it would have taken to code this by hand, but I think reproducing a UI that's been in production for years at Meta in three (part-time) days is pretty good!
Despite having done a bit of AI coding before, I also learned a bit from working with Codex. Codex made it blindingly clear that the parallel modality (and subsequent conflict resolution) is important. It made me revise upward my estimate of LLMs' capability to write raw HTML/JS and evoked a future where people vibe code components in place of taking on a third-party dependency. I was very appreciative of no-rate-limit Codex (though I doubt it's going to last.) It also reminded me how difficult it will be to set up agent environments for "real" projects (like PyTorch).
Hopefully, this case study has given you some ideas for things to try. Go forth and vibe code, responsibly!
by Edward Z. Yang at June 02, 2025 04:31 AM
I’ve been playing around with Nix and NixOS for the past week and honestly, I don’t really like NixOS. But, now I dislike other OSes even more since I finally understand what they’re all missing compared to NixOS!
Anyways, as you might know, this site is hosted on SIPB’s XVM service and uses Arch Linux. Since NixOS totally ruined Arch and every other Linux distro for me, I decided to switch the server for this site to using NixOS instead.
by Anthony Wang at April 14, 2025 11:07 PM
Do you use an LLM for coding? Do you maintain a personal benchmark based on problems you have posed the LLM? The purpose of this blog post is to convince you that you should do this: that you can do so with marginal effort on top of your day-to-day vibe coding, and that you will get both short- and long-term benefits from making your own personal benchmark exist.
I started thinking about benchmarks for coding in part out of frustration with the discourse around LLMs in the public squares I frequent (Reddit and Twitter). People often want to know "what's the best model" or "what's the best coding IDE"? One might imagine that the way to answer this question would be to test the models on a variety of problems from real world uses of the LLM for coding, and then compare how well various systems do on this. Indeed, whenever a new SOTA model releases, the lab will usually tell you about the model's performance against a few well known coding benchmarks. Problem solved?

Of course not! In fact, for the most part, no one really talks about benchmarks when comparing models. Why? I argue the most popular benchmarks measure tasks that are largely different from what a user wants out of an LLM. For example, take the recent Gemini 2.5 Pro release. In their headline table, they test against LiveCodeBench, Aider Polyglot and SWE-bench Verified. LiveCodeBench and Aider Polyglot derive their problems from contest programming and pedagogical exercises (respectively), while SWE-bench assesses bug fixes to preexisting codebases. While useful, this is only a small slice of the things people want to do with LLMs.
Wouldn't it be great if you had your own, personal benchmark, based on problems you actually care about? If you are tweaking your .cursorrules, you could run your benchmark to see if a change you made helped or not. When a new model comes out, you could spend a few bucks to run your eval and make a decision if you should switch your daily driver. And then on social media, if you wanted to stan the new model, instead of asking the model to drop a ball inside a rotating hexagon or vagueposting about how the new model is incredible, you could just post your benchmark results.
Nicholas Carlini's Yet Another Applied LLM Benchmark is an existence proof that this playbook can work. As Nicholas describes it:
It's a collection of nearly 100 tests I've extracted from my actual conversation history with various LLMs.
There are two defining features of this benchmark that make it interesting. Most importantly, I've implemented a simple dataflow domain specific language to make it easy for me (or anyone else!) to add new tests that realistically evaluate model capabilities. This DSL allows for specifying both how the question should be asked and also how the answer should be evaluated. Most questions are evaluated by actually running the code the model writes but the framework supports a bunch of other evaluation methods as well. And then, directly as a result of this, I've written nearly 100 tests for different situations I've actually encountered when working with LLMs as assistants.
I have been working on my own benchmark based off of Carlini's benchmark, and I can confirm that this works well for the traditional style of coding eval, where you have a one-shot task that generates and executes the code against some test cases. My basic strategy is to vibe code as usual, but whenever I give an LLM a task that it isn't able to one shot, I consider adding it to the benchmark. In more detail:
For example, the other day I needed to take an asciinema recording and convert it into a sequence of frames rendered as plain text. However, the only project for doing these conversions was agg, which converts recordings into animated gifs. In agg_to_text, I ask an LLM to take agg's source code and create a new program which dumps the frames as plain text rather than gif images. The reason this task is difficult is that there is some discretion in deciding when to emit a frame, and with my original prompt the LLM didn't precisely replicate the original behavior in agg. While working on the benchmark, I realized that instructing the model specifically about how frame batching worked was enough to get it to preserve the original behavior. But I don't think I should need to do this: thus this task. (P.S. If this test saturates, well, I can always make it harder by removing the agg source code from the prompt.)
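For one-shot tasks like these, the harness itself can stay quite small. Here is a minimal, self-contained sketch of the pattern (call_llm is a placeholder for whatever model API you actually use, and none of these names come from Carlini's framework):

import subprocess
import sys
import tempfile

def call_llm(prompt: str) -> str:
    # Placeholder: wire this up to your model of choice.
    raise NotImplementedError

def run_case(prompt: str, probe: str, expected_substring: str) -> bool:
    # Ask the model for code, append a small probe script, run it, and check
    # whether the output contains what we expect.
    code = call_llm(prompt)
    with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
        f.write(code + "\n\n" + probe + "\n")
        path = f.name
    result = subprocess.run([sys.executable, path], capture_output=True, text=True, timeout=120)
    return expected_substring in result.stdout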
The ability to benchmark one shot tasks is here today, but I would like to speculate a bit about what lies beyond them. In particular, most of my LLM coding activity involves asking the LLM to make changes to a pre-existing project, which makes it less amenable to "single prompt creates self contained program". (Also, I usually only ask one-shot questions that the LLM can answer, so most of them would never go in my benchmark.)
In short, how can I extract tasks from my day-to-day work? There seem to be two big extra levers we have:
I have started adapting Carlini's framework to work better for these cases, although I would love to be told someone has already solved this problem for me. In particular, I am very excited about using transcript tasks to evaluate whether or not things I add to my prompts / triggered rules are helping or not. Current SOTA model instruction following isn't great and I regularly catch models doing behaviors that I explicitly told them not to in the system prompt. I have started some initial analysis over all of my chat logs to find cases where the model misbehaved, although I haven't quite worked out how I want to build an eval out of it.
One word of warning: to make transcript tasks, you need an AI coding system that doesn't obscure how it assembles its underlying prompts (which rules out most of the popular closed source AI code editors.)
I started building evals for a selfish reason: I wanted to be able to tell if modifications to my prompts were doing anything. But I also think there is a broader opportunity that arises if we also publish these benchmarks to the world.
For one, building a real world benchmark on use cases we care about is a way to communicate to the people training AI models whether or not they are doing well. Historical evals have focused on LeetCoding, and consequently we have models that would ace any big tech interview and yet on real world tasks will drive you off a cliff at the first opportunity. And this is not just free labor for the top labs: if you believe in open source models, one of the biggest barriers to good small models is having really high quality data. We, the OSS vibe coding community, can directly help here.
I think there is a tremendous opportunity for the open source community to really push the state of the art in coding evaluations. There's only so many benchmarks that I, personally, can create, but if everyone is making benchmarks I could eventually imagine a universe of benchmarks where you could curate the problems that are relevant to your work and quickly and cheaply judge models in this way: a Wikipedia of Coding Benchmarks.
To summarize: every time an LLM fails to solve a problem you ask it for, this is a potential new benchmark. As long as there is a way to automate testing if the LLM has solved the problem, you can turn this into a benchmark. Do this for yourself, and you can quickly have a personal benchmark with which to evaluate new models. Do this at scale, and you can help push the frontier in coding models.
by Edward Z. Yang at April 04, 2025 07:05 AM
About a month ago, the CPython project merged a new implementation strategy for their bytecode interpreter. The initial headline results were very impressive, showing a 10-15% performance improvement on average across a wide range of benchmarks across a variety of platforms. Unfortunately, as I will document in this post, these impressive performance gains turned out to be primarily due to inadvertently working around a regression in LLVM 19. When benchmarked against a better baseline (such as GCC, clang-18, or LLVM 19 with certain tuning flags), the performance gain drops to 1-5% or so depending on the exact setup.
Earlier this month, I used Claude to port (parts of) an Emacs package into Rust, shrinking the execution time by a factor of 1000 or more (in one concrete case: from 90s to about 15ms). This is a variety of yak-shave that I do somewhat routinely, both professionally and in service of my personal computing environment. However, this time, Claude was able to execute substantially the entire project under my supervision without me writing almost any code myself, speeding up the project substantially compared to doing it by hand.
In my previous two posts "Ways to use torch.compile" and "Ways to use torch.export", I often said that PyTorch would be good for a use case, but there might be some downsides. Some of the downsides are foundational and difficult to remove. But some... just seem like a little something is missing from PyTorch. In this post, here are some things I hope we will end up shipping in 2025!
A programming model for PT2. A programming model is an abstract description of the system that is both simple (so anyone can understand it and keep it in their head all at once) and can be used to predict the system's behavior. The torch.export programming model is an example of such a description. Beyond export, we would like to help users understand why all aspects of PT2 behave the way they do (e.g., via improved error messages), and give them simple, predictable tools for working around problems when they arise. The programming model helps us clearly define the intrinsic complexity of our compiler, which we must educate users about. This is a big effort involving many folks on the PyTorch team and I hope we can share more about this effort soon.
Pre-compilation: beyond single graph export. Whenever someone realizes that torch.compile compilation is taking a substantial amount of time on expensive cluster machines, the first thing they ask is, "Why don't we just compile it in advance?" Supporting precompilation with the torch.compile API exactly as is turns out not to be so easy; unlike a traditional compiler, which gets the source program directly as input, users of torch.compile must actually run their Python program to hit the regions of code that are intended to be compiled. Nor can these regions be trivially enumerated and then compiled: not only must you know all the metadata of the input tensors flowing into a region, a user might not even know what the compiled graphs are if a model has graph breaks.
OK, but why not just run the model, dump all the compiled products, and then reuse them later? This works! Here is a POC from Nikita Shulga where a special decorator aot_compile_sticky_cache swaps between exporting a graph and running the exported product. Zhengxu Chen used a similar idea to export Whisper as a few distinct graphs, which he then manually stitched together in C++ to get a Python-free version of Whisper. If you want training to work, you can more directly integrate AOTInductor as an Inductor backend, e.g., as seen in this POC. We are a stone's throw away from working precompilation, which can guarantee no compilation at runtime; we just need to put the pieces together!
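To make the idea concrete, here is a minimal sketch of "export on first run, reuse thereafter"; this is an illustrative toy that caches in memory, not the actual POC decorator mentioned above (a real version would serialize the compiled artifact to disk and reload it later):

import torch
from torch.export import export

class MLP(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x @ x.T)

def sticky_export(module):
    # Export the module the first time it is called, then keep running the
    # exported product on subsequent calls.
    cache = {}
    def wrapper(*args):
        if "ep" not in cache:
            cache["ep"] = export(module, args)
        return cache["ep"].module()(*args)
    return wrapper

run = sticky_export(MLP())
out = run(torch.randn(4, 4))  # first call exports; later calls reuse the exported graph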
Improving caching further. There are some gaps with caching which we hope to address in the near future: (1) loading Triton cache artifacts takes a long time because we still re-parse the Triton code before doing a cache lookup (James Wu is on this), (2) if you have a lot of small graphs, remote cache ends up having to do lots of small network requests, instead of one batched network request at the beginning (Oguz Ulgen recently landed this), (3) AOTAutograd cache is not fully rolled out yet (James Wu again). These collectively should be worth a 2x speedup or even more on warm cache time.
Fix multithreading. We should just make sure multithreading works, doing the testing and fiddly thread safety auditing needed to make it work. Here's a list of multithreading related issues.
Draft mode export. Export requires a lot of upfront work to even get an exported artifact in the first place. Draft mode export capitalizes on the idea that it's OK to generate an unsound "draft" graph early in the export, because even an incorrect graph is useful for kicking the tires on the downstream processing that happens after export. A draft export gives you a graph, and it also gives you a report describing what potential problems need to be fixed to get some guarantees about the correctness of the export. You can then chip away on the problems in the report until everything is green. One of the biggest innovations of draft-mode export is pervasive use of real tensor propagation when doing export: you run the export with actual tensors, so you can always trace through code, even if it is doing spicy things like data-dependent control flow.
Libtorch-free AOTInductor. AOTInductor generated binaries have a relatively small ABI surface that needs to be implemented. This hack from the most recent CUDA Mode meetup shows that you can just create an alternate implementation of the ABI that has no dependence on libtorch. This makes your deployed binary size much smaller!
Support for bundling CUDA kernels into AOTInductor. AOTInductor already supports directly bundling Triton kernels into the generated binary, but traditional CUDA kernels cannot be bundled in this way. There's no reason this has to be the case though: all we're doing is bundling cubins in both cases. If we have the ability to bundle traditional CUDA kernels into AOTInductor, this means you could potentially embed custom operators directly into AOTInductor binaries, which is nice because then those operators no longer have to be provided by the runtime (especially if you're commonly iterating on these kernels!)
Export multigraphs. Export's standard model is to give you a single graph that you call unconditionally. But it's easy to imagine a level of indirection on top of these graphs, where we can dispatch between multiple graphs depending on some arguments to the model. For example, if you have a model that optionally takes an extra Tensor argument, you can simply have two graphs, one for when the Tensor is absent, and one for when it is present.
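As a sketch of the kind of indirection meant here (hypothetical hand-written dispatch code, not an existing export API), one could export one graph per calling convention and pick between them at call time:

import torch

class MyModel(torch.nn.Module):
    def forward(self, x, bias=None):
        out = x * 2
        if bias is not None:
            out = out + bias
        return out

model = MyModel()
x = torch.randn(4, 8)
bias = torch.randn(8)

# One exported graph per calling convention.
ep_no_bias = torch.export.export(model, (x,))
ep_with_bias = torch.export.export(model, (x, bias))

def run(x, bias=None):
    # The level of indirection: dispatch to whichever graph matches the arguments.
    if bias is None:
        return ep_no_bias.module()(x)
    return ep_with_bias.module()(x, bias)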
ABI stable PyTorch extensions. It's hard work being a third-party PyTorch extension with native code, because whenever there's a new release of Python or PyTorch you have to rebuild all of your wheels. If there were a limited ABI that you could build your extension against that didn't expose CPython and only relied on a small, stable ABI of PyTorch functions, your binary packaging situation would be much simpler! And if an extension relied on a small ABI, it could even be bundled with an AOTInductor binary, letting these export products be truly package agnostic (one of the lessons we learned with torch.package is that picking the split between "what is packaged" and "what is not" is very difficult, and people would much rather just have everything be packaged.) Jane Xu is investigating how to do this, and separately, Scott Wolchok has been refactoring headers in libtorch so that a small set of headers can be used independently of the rest of libtorch.
by Edward Z. Yang at January 09, 2025 08:50 PM
Previously, I discussed the value proposition of torch.compile. While doing so, I observed a number of downsides (long compile time, complicated operational model, lack of packaging) that were intrinsic to torch.compile's API contract, which emphasized being able to work on Python code as is, with minimal intervention from users. torch.export occupies a different spot in the tradeoff space: in exchange for more upfront work making a model exportable, it allows for use of PyTorch models in environments where using torch.compile as is would be impossible.
Scenario: Like before, suppose you want to deploy your model for inference. However, now you have more stringent runtime requirements: perhaps you need to do inference from a CPython-less environment (because your QPS requirements require GIL-less multithreading; alternately, CPython execution overhead is unacceptable but you cannot use CUDA graphs, e.g., due to CPU inference or dynamic shapes requirements). Or perhaps your production environment requires hermetic deploy artifacts (for example, in a monorepo setup, where infrastructure code must be continually pushed but model code should be frozen). But like before, you would prefer not to have to rewrite your model; you would like the existing model to serve as the basis for your Python-less inference binary.
What to do: Use torch.export targeting AOTInductor. This will compile the model into a self-contained shared library which then can be directly invoked from a C++ runtime. This shared library contains all of the compiler generated Triton kernels as precompiled cubins and is guaranteed not to need any runtime compilation; furthermore, it relies only on a small runtime ABI (with no CPython dependency), so the binaries can be used across versions of libtorch. AOTInductor's multithreading capability and low runtime overhead also make it a good match for CPU inference!
You don't have to go straight to C++ CPU/GPU inference: you can start with using torch.compile on your code before investing in torch.export. There are four primary extra requirements export imposes: (1) your model must compile with fullgraph=True (though you can sometimes bypass missing Dynamo functionality by using non-strict export; sometimes, it is easier to do non-strict torch.export than it is to torch.compile!), (2) your model's inputs/outputs must only be in torch.export's supported set of argument types (think Tensors in pytrees), (3) your model must never recompile--specifically, you must specify what inputs have dynamic shapes, and (4) the top-level of your model must be an nn.Module (so that export can keep track of all of the parameters your model has).
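For requirement (3), for example, you declare up front which input dimensions are dynamic. A minimal sketch, assuming a recent torch.export API and a toy module:

import torch
from torch.export import Dim, export

class M(torch.nn.Module):
    def forward(self, x):
        return x.sin() + 1

m = M()
example = torch.randn(8, 16)

# Mark dimension 0 of x as dynamic so the exported graph serves any batch size
# without recompiling; the other dimension stays specialized to 16.
batch = Dim("batch")
ep = export(m, (example,), dynamic_shapes={"x": {0: batch}})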
Some tips:
Open source examples: Among other things, torchchat has an example end-to-end AOTInductor setup for server-side LLM inference, which you can view in run.cpp.
torch.export specific downsides:
AOTInductor specific downsides:
Scenario: You need to deploy your PyTorch model to edge devices (e.g., a mobile phone or a wearable device) where computational resources are limited. You have requirements that are a bit different from the server side: you care a lot more about minimizing binary size and startup time. Traditional PyTorch deployment with full libtorch won't work. The device you're deploying to might also have some strange extra processors, like a DSP or NPU, that you want your model to target.
What to do: Use torch.export targeting Executorch. Among other things, Executorch offers a completely separate runtime for exported PyTorch programs (i.e., it has no dependency on libtorch, except perhaps there are a few headers which we share between the projects) which was specifically designed for edge deployment. (Historical note: we spent a long time trying to directly ship a stripped down version of libtorch to mobile devices, but it turns out it's really hard to write code that is portable on server and client, so it's better to only share when absolutely necessary.) Quantization is also a pretty important part of deployment to Edge, and Executorch incorporates this into the end-to-end workflow.
Open source examples: torchchat also has an Executorch integration letting you run an LLM on your Android phone.
Downsides. All of the export related downsides described previously apply here. But here's something to know specifically about Executorch:
Scenario: You need a new function or self-contained module with an efficient kernel implementation. However, you would prefer not to have to write the CUDA (or even Triton) by hand; the kernel is something that torch.compile can generate from higher level PyTorch implementation. At the same time, however, you cannot tolerate just-in-time compilation at all (perhaps you are doing a massive training job, and any startup latency makes it more likely that one of your nodes will fail during startup and then you make no progress at all; or maybe you just find it annoying when PyTorch goes out to lunch when you cache miss).
What to do: Use torch.export targeting AOTInductor, and then load and run the AOTInductor generated binary from Python.
Downsides. So, we know this use case works, because we have internally used this to unblock people who wanted to use Triton kernels but could not tolerate Triton's just-in-time compilation. But there's not much affordance in our APIs for this use case; for example, guard-based dispatch is often quite useful for compiled functions, but you'll have to roll that by hand. More generally, when compiling a kernel, you have to make tradeoffs about how static versus dynamic the kernel should be (for example, will you force the inputs to be evenly divisible by eight? Or would you have a separate kernel for the divisible and not divisible cases?) Once again, you're on your own for making the call there.
Scenario: In an ideal world, you would have a model, you could export it to an AOTInductor binary, and then be all done. In reality, maybe this export process needs to be a multi-stage process, where it has to be processed to some degree on one machine, and then finish processing on another machine. Or perhaps you need to shift the processing over time: you want to export a model to freeze it (so it is no longer tied to its original source code), and then repeatedly run the rest of the model processing pipeline on this exported program (e.g., because you are continuously updating its weights and then reprocessing the model). Maybe you want to export the model and then train it from Python later, committing to a distributed training strategy only when you know how many nodes you are running. The ability to hermetically package a model and then process it later is one of the big value propositions of TorchScript and torch.package.
What to do: Use torch.export by itself, potentially using pre-dispatch if you need to support training use-cases. torch.export produces an ExportedProgram which has a clean intermediate representation that you can do processing on, or just serialize and then do processing on later.
Downsides:
Next time: What's missing, and what we're doing about it
by Edward Z. Yang at December 24, 2024 04:28 AM
Like free stuff? If you’re an MIT student, did you know you can get a free virtual machine, courtesy of SIPB’s XVM project? (And if you’re not an MIT student, well sorry, this guide probably won’t be very useful then.) That’s right, your own tiny VM, for free! Mandatory disclaimer: the VMs are indeed… very tiny, with only one core and 512 MiB of RAM, and XVM has some quirks. But you can still do a lot of cool stuff with it!
by Anthony Wang at December 11, 2024 04:45 PM
On the surface, the value proposition of torch.compile is simple: compile your PyTorch model and it runs X% faster. But after having spent a lot of time helping users from all walks of life use torch.compile, I have found that actually understanding how this value proposition applies to your situation can be quite subtle! In this post, I want to walk through the ways to use torch.compile, and within these use cases, what works and what doesn't. By the way, some of these gaps are either served by export or by missing features we are actively working on; those will be covered in other posts!
Scenario: You have a model in PyTorch that you want to train at a small-medium scale (e.g., below 1K GPUs--at the 1K point there is a phase change in behavior that deserves its own section). You would like it to train faster. Locally, it's nice to get a trained model faster than you would have otherwise. But globally, the faster everyone's models train, the less GPU hours they use, which means you can run more jobs in a given time window with a fixed cluster. If your supply of GPUs is inelastic (lol), efficiency improvement means you can support more teams and use cases for the same amount of available GPUs. At a capacity planning level, this can be a pretty big deal even if you are GPU rich.
What to do: In some sense, this is the reason we built torch.compile. (When we were initially planning torch.compile, we were trying to assess if we were going after inference; but inference compilers are a much more crowded space than training compilers, and we reasoned that if we did a good job building a training compiler, inference would work too--which it did!) The dream which we sold with torch.compile is that you could slap it on the top of your model and get a speed up. This turns out to... not quite be true? But the fact remains that if you're willing to put in some work, there is almost always performance waiting at the end of the road for you. Some tips:
Open source examples: torchtune and torchtitan are two first party libraries which are intended to showcase modern PyTorch using torch.compile in a training context. There's also some training in torchao.
Downsides:
Scenario: You've finished training your model and you want to deploy it for inference. Here, you want to improve the efficiency of inference to improve response latency or reduce the overall resource requirements of the system, so you can use less GPUs to serve the traffic you are receiving. Admittedly, it is fairly common to just use some other, more inference friendly systems (which I will decline to name by name lol) to serve the model. But let's say you can't rewrite the model in a more serving friendly language (e.g., because the model authors are researchers and they keep changing the model, or there's a firehose of models and you don't have the money to keep continuously porting each of them, or you depend on an ecosystem of libraries that are only available in CPython).
What to do: If Python can keep up with the CPU-side QPS requirements, a way of getting good performance without very much work is taking the Python model, applying torch.compile on it in the same way as you did in training and directly using this as your inference solution. Some tips that go beyond training:
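As a minimal sketch of the basic pattern just described (with a stand-in model), the same torch.compile call you used in training is simply wrapped in inference mode:

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(128, 64)

    def forward(self, x):
        return torch.relu(self.proj(x))

model = MyModel().eval()
compiled = torch.compile(model)  # same call as in training

with torch.inference_mode():
    out = compiled(torch.randn(8, 128))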
Open source examples: LLM serving on torch.compile is quite popular: vllm, sglang, tensorrt-llm, gpt-fast (this is technically not an E2E serving solution, but one of its primary reasons for existing is to serve as a starting point so you can build your own torch.compile based LLM inference stack on top of it). Stable diffusion models are also notable beneficiaries of torch.compile, e.g., diffusers.
Downsides:
Scenario: In both the cases above, we assumed that we had a preexisting eager model that worked, and we just wanted to make it faster. But you can also use the compiler in a load bearing way, where the model does not work without the compiler. Here are two common cases where this can occur:
What to do: Unlike in the previous cases, where you took a preexisting model and slapped torch.compile on it, this sort of use of the compiler is more likely to arise from a codevelopment approach, where you use torch.compile while you build your model, and are constantly checking what the compiler does to the code you write. Some tips:
Open source examples. SimpleFSDP as mentioned above. VLLM uses torch.compile to apply custom optimization passes. Although its implementation is considerably more involved than what you might reasonably expect a third party to implement, FlexAttention is a good example of a non-compiler feature that relies on the compiler in a load-bearing way for performance.
Downsides: Beyond the ones mentioned above:
Next time: Ways to use torch.export
by Edward Z. Yang at November 05, 2024 03:11 PM
Tensor libraries like PyTorch and JAX have developed compact and accelerated APIs for manipulating n-dimensional arrays. N-dimensional arrays are kind of similar to tables in a database, which raises a logical question: could you set up a Tensor-like API to do queries on databases that would normally be done with SQL? We have two challenges:
However, we have a secret weapon: first class dimensions were primarily designed as a new frontend syntax that made it easy to express einsum, batching and tensor indexing expressions. They might be good for SQL too.
Representing the database. First, how do we represent a database? A simple model, following columnar databases, is to have every column be a distinct 1D tensor, where all columns part of the same table have a consistent indexing scheme. For simplicity, we'll assume that we support rich dtypes for the tensors (e.g., so I can have a tensor of strings). So if we consider our classic customer database of (id, name, email), we would represent this as:
customers_id: int64[C]
customers_name: str[C]
customers_email: str[C]
Where C is the number of entries in the customer database. Our tensor type is written as dtype[DIM0, DIM1, ...], where I reuse the name that I will use for the first class dimension that represents it. Let's suppose that the index into C does not coincide with id (which is good, because if they did coincide, you would have a very bad time if you ever wanted to delete an entry from the database!)
This gives us an opportunity for baby's first query. Let's implement the following:
SELECT c.name, c.email FROM customers c WHERE c.id = 1000
Notice that the result of this operation is data-dependent: it may be zero or one depending on if the id is in the database. Here is a naive implementation in standard PyTorch:
mask = customers_id == 1000
return (customers_name[mask], customers_email[mask])
Here, we use boolean masking to perform the data-dependent filtering operation. This implementation in eager is a bit inefficient; we materialize a full boolean mask that is then fed into the subsequent operations; you would prefer for a compiler to fuse the masking and indexing together. First class dimensions don't really help with this example as is; we will need to introduce some new extensions to first class dimensions. First, here is what we can do:
C = dims(1)
c_id = customers_id[C] # {C} => int64[]
c_name = customers_name[C] # {C} => str[]
c_email = customers_email[C] # {C} => str[]
c_mask = c_id == 1000 # {C} => bool[]
Here, a tensor with first class dimensions has a more complicated type {DIM0, DIM1, ...} => dtype[DIM2, DIM3, ...]. The first class dimensions are all reported in the curly braces to the left of the double arrow; curly braces are used to emphasize the fact that first class dimensions are unordered.
What next? The problem is that now we want to do something like torch.where(c_mask, c_name, ???) but we are now in a bit of trouble, because we don't want anything in the false branch of where: we want to provide something like "null" and collapse the tensor to a smaller number of elements, much like how boolean masking did it without first class dimensions. To express this, we'll introduce a binary version of torch.where that does exactly this, as well as returning the newly allocated FCD for the new, data-dependent dimension:
C2, c2_name = torch.where(c_mask, c_name) # {C2} => str[]
_C2, c2_email = torch.where(c_mask, c_email) # {C2} => str[], n.b. C2 == _C2
return c2_name, c2_email
Notice that torch.where introduces a new first-class dimension. I've chosen that this FCD gets memoized with c_mask, so whenever we do more torch.where invocations we still get consistently the same new FCD.
Having to type out all the columns can be a bit tiresome. If we assume all elements in a table have the same dtype (let's call it dyn, short for dynamic type), we can more compactly represent the table as a 2D tensor, where the first dimension is the indexing as before, and the second dimension is the columns of the database. For clarity, we'll support using the string name of the column as a shorthand for the numeric index of the column. If the tensor is contiguous, this gives a more traditional row-wise database. The new database can be conveniently manipulated with FCDs, as we can handle all of the columns at once instead of typing them out individually:
customers: dyn[C, C_ATTR]
C = dims(1)
c = customers[C] # {C} => dyn[C_ATTR]
C2, c2 = torch.where(c["id"] == 1000, c) # {C2} => dyn[C_ATTR]
return c2[["name", "email"]].order(C2) # dyn[C2, ["name", "email"]]
We'll use this for the rest of the post, but the examples should be interconvertible.
Aggregation. What's the average age of all customers, grouped by the country they live in?
SELECT AVG(c.age) FROM customers c GROUP BY c.country;
PyTorch doesn't natively support this grouping operation, but essentially what is desired here is a conversion into a nested tensor, where the jagged dimension is the country (each of which will have a varying number of customers). Let's hallucinate a torch.groupby analogous to its Pandas equivalent:
customers: dyn[C, C_ATTR]
customers_by_country = torch.groupby(customers, "country") # dyn[COUNTRY, JC, C_ATTR]
COUNTRY, JC = dims(2)
c = customers_by_country[COUNTRY, JC] # {COUNTRY, JC} => dyn[C_ATTR]
return c["age"].mean(JC).order(COUNTRY) # f32[COUNTRY]
Here, I gave the generic indexing dimension the name JC, to emphasize that it is a jagged dimension. But everything proceeds like we expect: after we've grouped the tensor and rebound its first class dimensions, we can take the field of interest and explicitly specify a reduction on the dimension we care about.
In SQL, aggregations have to operate over the entirety of groups specified by GROUP BY. However, because FCDs explicitly specify what dimensions we are reducing over, we can potentially decompose a reduction into a series of successive reductions on different columns, without having to specify subqueries to progressively perform the reductions we are interested in.
Joins. Given an order table, join it with the customer referenced by the customer id:
SELECT o.id, c.name, c.email FROM orders o JOIN customers c ON o.customer_id = c.id
First class dimensions are great at doing outer products (although, like with filtering, a naive implementation will expensively materialize the entire outer product!)
customers: dyn[C, C_ATTR]
orders: dyn[O, O_ATTR]
C, O = dims(2)
c = customers[C] # {C} => dyn[C_ATTR]
o = orders[O] # {O} => dyn[O_ATTR]
mask = o["customer_id"] == c["id"] # {C, O} => bool[]
outer_product = torch.cat(o[["id"]], c[["name", "email"]]) # {C, O} => dyn[["id", "name", "email"]]
CO, co = torch.where(mask, outer_product) # {CO} => dyn[["id", "name", "email"]]
return co.order(CO) # dyn[CO, ["id", "name", "email"]]
What's the point. There are a few reasons why we might be interested in the correspondence here. First, we might be interested in applying SQL ideas to the Tensor world: a lot of things people want to do in preprocessing are similar to what you do in traditional relational databases, and SQL can teach us what optimizations and what use cases we should think about. Second, we might be interested in applying Tensor ideas to the SQL world: in particular, I think first class dimensions are a really intuitive frontend for SQL which can be implemented entirely embedded in Python without necessitating the creation of a dedicated DSL. Also, this might be the push needed to get TensorDict into core.
by Edward Z. Yang at October 14, 2024 05:07 AM
One of the things that I learned in grad school is that even if you've picked an important and unsolved problem, you need some reason to believe it is solvable--especially if people have tried to solve it before! In other words, "What's different this time?" This is perhaps a dreary way of shooting down otherwise promising research directions, but you can flip it around: when the world changes, you can ask, "What can I do now that I couldn't do before?"
This post is a list of problems in areas that I care about (half of this is PL flavor, since that's what I did my PhD in), where I suspect something has changed with the advent of LLMs. It's not a list of recipes; there is still hard work to figure out how exactly an LLM can be useful (for most of these, just feeding the entire problem into ChatGPT usually doesn't work). But I often talk to people who want to get started on something, anything, but have no idea where to start. Try here!
Static analysis. The chasm between academic static analysis work and real world practice is the scaling problems that come with trying to apply the technique to a full size codebase. Asymptotics strike as LOC goes up, language focused techniques flounder in polyglot codebases, and "Does anyone know how to write cmake?" But this is predicated on the idea that static analysis has to operate on a whole program. It doesn't; humans can do perfectly good static analysis on fragments of code without having to hold the entire codebase in their head, without needing access to a build system. They make assumptions about APIs and can do local reasoning. LLMs can play a key role in drafting these assumptions so that local reasoning can occur. What if the LLM gets it wrong? Well, if an LLM could get it wrong, an inattentive junior developer might get it wrong too--maybe there is a problem in the API design. LLMs already do surprisingly well if you one-shot prompt them to find bugs in code; with more traditional static analysis support, maybe they can do even better.
DSL purgatory. Consider a problem that can be solved with code in a procedural way, but only by writing lots of tedious, error prone boilerplate (some examples: drawing diagrams, writing GUIs, SQL queries, building visualizations, scripting website/mobile app interactions, end to end testing). The PL dream is to design a sweet compositional DSL that raises the level of abstraction so that you can render a Hilbert curve in seven lines of code. But history also abounds with cases where the DSL did not solve the problems, or maybe it did solve the problem but only after years of grueling work, and so there are still many problems that feel like there ought to be a DSL that should solve them but there isn't. The promise of LLMs is that they are extremely good at regurgitating low level procedural actions that could conceivably be put together in a DSL. A lot of the best successes of LLMs today come from putting coding power in the hands of domain experts who otherwise do not know how to code; could it also help in putting domain expertise in the hands of people who can code?
I am especially interested in these domains:
OSS bread and butter. Why is Tesseract still the number one OSS library for OCR? Why is smooth and beautiful text to voice not ubiquitous? Why is the voice control on my Tesla so bad? Why is the wake word on my Android device so unreliable? Why isn't the screenshot parser on a fansite for my favorite mobage able to parse out icons? The future has arrived, but it is not uniformly distributed.
Improving the pipeline from ephemeral to durable stores of knowledge. Many important sources of knowledge are trapped in "ephemeral" stores, like Discord servers, private chat conversations, Reddit posts, Twitter threads, blog posts, etc. In an ideal world, there would be a pipeline of this knowledge into more durable, indexable forms for the benefit of all, but actually doing this is time consuming. Can LLMs help? Note that the dream of LLMs is you can just feed all of this data into the model and just ask questions to it. I'm OK with something a little bit more manual, we don't have to solve RAG first.
by Edward Z. Yang at October 04, 2024 04:30 AM
Suppose we have a large collection of documents, and we wish to identify which documents are approximately the same as each other. For instance, we may have crawled the web over some period of time, and expect to have fetched the “same page” several times, but to see slight differences in metadata, or that we have several revisions of a page following small edits. In this post I want to explore the method of approximate deduplication via Jaccard similarity and the MinHash approximation trick.
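As a taste of the technique, here is a minimal, illustrative MinHash sketch; a real deduplication system would use a proper hash family and LSH banding rather than seeded SHA-1 over character shingles:

import hashlib

def shingles(text, k=5):
    # Character k-grams; word-level shingles work too.
    return {text[i:i + k] for i in range(max(len(text) - k + 1, 1))}

def minhash_signature(text, num_hashes=64):
    # For each "hash function" (here, SHA-1 salted with a seed), keep the
    # minimum hash value over all shingles of the document.
    sig = []
    for seed in range(num_hashes):
        sig.append(min(
            int(hashlib.sha1(f"{seed}:{s}".encode()).hexdigest(), 16)
            for s in shingles(text)
        ))
    return sig

def estimated_jaccard(sig_a, sig_b):
    # The fraction of matching signature slots estimates the Jaccard
    # similarity of the underlying shingle sets.
    return sum(a == b for a, b in zip(sig_a, sig_b)) / len(sig_a)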
I worked at Stripe for about seven years, from 2012 to 2019. Over that time, I used and contributed to many generations of Stripe’s developer environment – the tools that engineers used daily to write and test code. I think Stripe did a pretty good job designing and building that developer experience, and since leaving, I’ve found myself repeatedly describing features of that environment to friends and colleagues. This post is an attempt to record the salient features of that environment as I remember it.
I was recently introduced to the paper “Seeing the Invisible: Perceptual-Cognitive Aspects of Expertise” by Gary Klein and Robert Hoffman. It’s excellent and I recommend you read it when you have a chance. Klein and Hoffman discuss the ability of experts to “see what is not there”: in addition to observing data and cues that are present in the environment, experts perceive implications of these cues, such as the absence of expected or “typical” information, the typicality or atypicality of observed data, and likely/possible past and future time trajectories of a system based on a point-in-time snapshot or limited duration of observation.
This December, the imp of the perverse struck me, and I decided to see how many days of Advent of Code I could do purely in compile-time C++ metaprogramming. As of this writing, I’ve done two days, and I’m not sure I’ll make it any further. However, that’s one more day than I planned to do as of yesterday, which is in turn further than I thought I’d make it after my first attempt.
I have spent many years as a software engineer who was a total outsider to machine-learning, but with some curiosity and occasional peripheral interactions with it. During this time, a recurring theme for me was horror (and, to be honest, disdain) every time I encountered the widespread usage of Python pickle in the Python ML ecosystem. In addition to their major security issues, the use of pickle for serialization tends to be very brittle, leading to all kinds of nightmares as you evolve your code and upgrade libraries and Python versions.
Suppose we’ve got a service. We’ll gloss over the details for now, but let’s stipulate that it accepts requests from the outside world, and takes some action in response. Maybe those requests are HTTP requests, or RPCs, or just incoming packets to be routed at the network layer. We can get more specific later. What can we say about its performance? All we know is that it receives requests, and that it acts on them.
What’s the “right” level of CPU utilization for a server? If you look at a monitoring dashboard from a well-designed and well-run service, what CPU utilization should we hope to see, averaged over a day or two? It’s a very general question, and it’s not clear it should have a single answer. That said, for a long time, I generally believed that higher is always better: we should aim for as close to 100% utilization as we can.
Ever since its introduction in the 2017 paper, Attention is All You Need, the Transformer model architecture has taken the deep-learning world by storm. Initially introduced for machine translation, it has become the tool of choice for a wide range of domains, including text, audio, video, and others. Transformers have also driven most of the massive increases in model scale and capability in the last few years. OpenAI’s GPT-3 and Codex models are Transformers, as are DeepMind’s Gopher models and many others.
In my day job at Anthropic, we run relatively large distributed systems to train large language models. One of the joys of using a lot of computing resources, especially on somewhat niche software stacks, is that you spend a lot of time running into the long-tail of bugs which only happen rarely or in very unusual configurations, which you happen to be the first to encounter. These bugs are frustrating, but I also often enjoy them.
One of the annoying things about scraping websites is bouncing back and forth between the browser where you are using Dev Tools to work out what selectors you should be using to scrape out data, and your actual scraping script, which is usually some batch program that may have to take a few steps before the step you are debugging. A batch script is fine once your scraper is up and running, but while developing, it's really handy to pause the scraping process at some page and fiddle around with the DOM to see what to do.
This interactive-style development is exactly what Jupyter notebooks shine at; when used in conjunction with a browser-based scraping library like Puppeteer, you can have exactly this workflow. Here's the setup:
There will be a live browser instance which you can poke at using Dev Tools, and you type commands into the Jupyter notebook and see how they affect the browser state.
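The post's setup pairs Jupyter with Puppeteer; as an analogous sketch using Playwright's Python async API instead (which cooperates with Jupyter's event loop, and where example.com stands in for whatever site you're scraping), the notebook cells might look like:

# Cell 1: launch a visible browser you can inspect with Dev Tools.
from playwright.async_api import async_playwright

p = await async_playwright().start()
browser = await p.chromium.launch(headless=False)
page = await browser.new_page()
await page.goto("https://example.com")

# Cell 2: try selectors interactively until they extract what you want.
await page.locator("h1").inner_text()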
I tweeted about this and the commenters had some good suggestions about other things you could try:
by Edward Z. Yang at November 23, 2021 02:28 PM
CPU cycles are cheaper than they have ever been, and cloud computing has never been more ubiquitous. All the major cloud providers offer generous free tiers, and services like GitHub Actions offer free compute resources to open-source repositories. So why do so many developers still build software on their laptops? Despite the embarrassment of riches of cheap or even free cloud compute, most projects I know of, and most developers, still do most of their software development — building and running code — directly on their local machines.
Last week, Frederic Cambus wrote about building LLVM quickly on some very large machines, culminating in a 2m37s build on a 160-core ARM machine. I don’t have a giant ARM behemoth, but I have been working on a tool I call Llama, which lets you offload computational work – including C and C++ builds – onto Amazon Lambda. I decided to see how well it could do on a similar build.
I'm launching a new podcast, the PyTorch Developer Podcast. The idea is to be a place for the PyTorch dev team to do bite-sized (10-20 min) episodes about all sorts of internal development topics in PyTorch. For now, it's just me monologuing for fifteen minutes about whatever topic I decide. The plan is to release an episode daily, five days a week, until I run out of things to say (probably not for a while, I have SO MANY THINGS TO SAY). I don't edit the podcasts and do minimal planning, so they're a bit easier to do than blog posts. Check it out! There are two episodes out already, one about how we do Python bindings for our C++ objects and another about the history and constraints of the dispatcher. If there are any topics you'd like me to cover, give a shout.
by Edward Z. Yang at May 05, 2021 03:26 PM
At Facebook, we have an internal convention for tooling called "rage". When something goes wrong and you want to report a bug, the tool developer will typically ask you to give them a rage. For a command line tool, this can be done by running a rage subcommand, which will ask which previous CLI invocation you'd like to report, and then give you a bundle of logs to send to the developer.
A rage has an important property, compared to a conventional log level flag like -v: rage recording is always on. In other words, it is like traditional server application logs, but applied to client software. Logging is always turned on, and the rage subcommand makes it easy for a user to send only the relevant portion of the logs (e.g., the logs associated with the command line invocation in question).
For some reason, rage functionality is not that common in open source tools. I can imagine any number of reasons why this might be the case:
Still, in the same way most sysadmins view logging as an invaluable tool for debugging server issues, I think rage reporting is an invaluable tool for debugging client issues. In ghstack, it didn't take very many lines of code to implement rage reporting: ghstack.logs (for writing the logs to the rage directory) and ghstack.rage (for reading it out). But it has greatly reduced my support load for the project; given a rage, I can typically figure out the root cause of a bug without setting up a reproducer first.
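For a sense of what "not very many lines of code" means here, a minimal sketch of the idea in Python follows; this is a hypothetical illustration, not ghstack's actual ghstack.logs/ghstack.rage code. Every invocation logs unconditionally into its own directory, and a rage subcommand lists recent invocations so the user can pick one to send:
import logging, sys, time
from pathlib import Path

LOG_ROOT = Path.home() / ".mytool" / "logs"   # hypothetical location

def setup_logging() -> None:
    # Always-on logging: every CLI invocation gets its own directory, no -v required.
    log_dir = LOG_ROOT / time.strftime("%Y%m%d-%H%M%S")
    log_dir.mkdir(parents=True, exist_ok=True)
    logging.basicConfig(filename=str(log_dir / "debug.log"), level=logging.DEBUG)
    logging.info("argv: %s", sys.argv)

def rage() -> None:
    # Let the user pick a previous invocation to bundle up and send to the developer.
    for i, d in enumerate(sorted(LOG_ROOT.iterdir(), reverse=True)[:10]):
        print(i, d.name)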
by Edward Z. Yang at April 25, 2021 04:03 AM
People who work with me tend to realize that I have Opinions about databases, and SQL databases in particular. Last week, I wrote about a Postgres debugging story and tweeted about AWS’ policy ban on internal use of SQL databases, and had occasion to discuss and debate some of those feelings on Twitter; this article is an attempt to write up more of them into a single place I can refer to.
PyTorch is a fairly large and active open source project, and sometimes we have people come to us and ask if there are any lessons from how we run PyTorch that they could apply to their own projects. This post is an attempt to describe some of the processes as of 2021 that help PyTorch operate effectively as an open source project. I won't claim that everything we do is necessarily the best way to go about doing things, but at the very least, everything I describe here is working in practice.
Background. Not all open source projects are the same, and there are some peculiarities to PyTorch which may reduce the applicability of some of what I describe below in other contexts. Here are some defining features of PyTorch, as a project:
Alright, so how does PyTorch deal with its scale? Here are some of the things we do.
Issue triage. PyTorch receives too many bug reports a day for any one person to keep track of all of them. Largely inspired by this apenwarr post, we set up an oncall rotation amongst Facebook contributors to serve as first-line triage for all of these issues. The golden rule of issue triage is that you DO NOT fix bugs in triage; the goal of triage is to (1) route bugs to the correct people via appropriate GitHub labels, and (2) look for high priority bugs and raise awareness of these bugs. Every week, we have a meeting to review high priority bugs (and other bugs marked for triage review) and talk about them. The oncall itself rotates daily, to discourage people from letting a week's worth of issues pile up in the backlog, and we use a relatively intricate search query to make sure only relevant issues show up for the oncall to handle.
The most important consequence of issue triage is that you can unwatch the PyTorch repository as a whole. Instead, by watching various labels (using our cc bot), you can trust that you will get CC'ed on issues related to those topics, even if the triager doesn't know that you're interested in the issue! The weekly meeting makes sure that all maintainers collectively have an idea about what major issues are currently affecting PyTorch, and helps socialize what we as a project think of as a "high priority" issue. Finally, the high priority label is a good way to find impactful problems to work on in the project, even if you don't know much else about the project.
Pull request triage. Similarly, we receive a decent number of drive-by pull requests from one-time contributors. Those people are not in a good position to find reviewers for their contributions, so we also have a triager look through these pull requests and make sure someone is assigned to review them. If the PR is particularly simple, the triager might just go ahead and merge it themselves. There's actually some good automation for doing this (e.g., homu) but we've been too lazy to set any of it up, and by-hand reviewer assignment doesn't seem to be too much of a burden on top of the existing oncall.
Tree hugging oncall. PyTorch has a huge CI system covering many different system configurations, which most contributors rely on to test if their changes are safe. Sometimes people break master. Separate from the triage oncall, we have a tree hugging oncall whose job it is to revert commits if they break master. This oncall involves mostly paying attention to the CI HUD and reverting commits if they result in master breakage in one of the configurations.
Importing to Facebook infrastructure. We actually run Facebook infrastructure directly off of the HEAD branch in PyTorch. The tooling that makes this possible is fbshipit, which mirrors commits between Facebook's internal monorepo and our public GitHub repository. This setup has been something of a double-edged sword for us: requiring Facebook and GitHub to be in sync means that only Facebook employees can actually land pull requests (we try to streamline the process as much as possible for external maintainers, but at the end of the day someone at Facebook has to actually push the green button), but it means we don't have to worry about doing periodic "mega-imports" into Facebook infrastructure (which we have done in the past and were quite difficult to do). We are very interested in fixing this situation and have floated some proposals on changing how we do internal releases to make it possible to let external contributors land PRs directly.
RFCs. Most feature discussion happens on GitHub issues, but sometimes, a feature is too big and complicated to adequately discuss in a GitHub issue. In those cases, they can be discussed in the rfcs repository (inspired by the Rust RFCs process). The formal process on this repository isn't too solidified yet, but generally people go there if they feel that it is too difficult to discuss the issue in GitHub issues. We don't yet have a process for shepherding unsolicited RFCs.
Conclusion. PyTorch's open source process isn't rocket science: there's an oncall, and the oncall does some things. The devil is in the details: all of PyTorch's oncall responsibilities are carefully scoped so that they don't take an unbounded amount of time; they're something you can knock out in an hour or two and call it a day. You could make the argument that we rely excessively on oncalls when automation is possible, but what we have found is that oncalls require less infrastructure investment, and integrate well with existing processes and flows at Facebook. They might not be right everywhere, but at least for us they seem to be doing a good job.
by Edward Z. Yang at January 06, 2021 04:56 PM
Years ago, Nadav Rotem related to me this story about why basic block procedures in Swift are not as good as they seem. Nelson Elhage reminded me about this on Twitter and so I thought this should be put into the public record.
Basic block procedures make certain optimizations more difficult. Consider this program:
block j3 (%y1, %y2) { ... }
block j1 () { jump j3(%x1, %x2) }
block j2 () { jump j3(%x3, %x4) }
Is this program easier or more difficult to optimize than the traditional SSA with phi-nodes formulation?
L1: goto L3
L2: goto L3
L3: %y1 = phi [%x1, %L1] [%x3, %L2]
    %y2 = phi [%x2, %L1] [%x4, %L2]
Suppose that the optimizer determines that y1 is unused inside j3/L3 and can be eliminated. In traditional SSA (phi-node) land, y1 can be eliminated simply by deleting "%y1 = phi [%x1, %L1] [%x3, %L2]". However, in join point land, you have to not only eliminate y1 but also update all the call sites of j3, since you've changed the function signature. In a mutable AST, changing function signatures is a pain; in particular, the mutations you would have to do to eliminate the argument include intermediate states that are not valid ASTs (making it easy to accidentally trigger asserts.)
When I saw this example, I wondered why GHC (which has the moral equivalent of basic block procedures in the form of join points) didn't have this problem. Well, it turns out this optimization can be done as a series of local transformations. First, we do a worker/wrapper transformation, introducing an intermediate block (the worker) that drops the dead argument:
block j3 (%y1, %y2) { jump wj3(%y2) }
block j1 () { jump j3(%x1, %x2) }
block j2 () { jump j3(%x3, %x4) }
block wj3 (%y2) { ... }
Later, we inline j3, which removes the wrapper. Worker/wrapper is a very important optimization for functional programs, but it's easy to imagine why it is less preferred in mutable compiler land.
by Edward Z. Yang at October 24, 2020 11:34 PM
Greetings from 2024! An official pattern matching PEP has been accepted (https://peps.python.org/pep-0636/) and is available in Python 3.10. Class patterns are tested using isinstance, with no inheritance structure necessary, making the pattern described in this post 100% forward compatible with real pattern matching.
One of the features I miss most in non-Haskell programming languages is algebraic data types (ADTs). ADTs fulfill a similar role to objects in other languages, but with more restrictions: objects are an open universe, where clients can implement new subclasses that were not known at definition time; ADTs are a closed universe, where the definition of an ADT specifies precisely all the cases that are possible. We often think of restrictions as a bad thing, but in the case of ADTs, the restriction of being a closed universe makes programs easier to understand (a fixed set of cases to understand, as opposed to a potentially infinite set of cases) and allows for new modes of expression (pattern matching). ADTs make it really easy to accurately model your data structures; they encourage you to go for precise types that make illegal states unrepresentable. Still, it is generally not a good idea to try to manually reimplement your favorite Haskell language feature in every other programming language you use, and so for years I've suffered in Python under the impression that ADTs were a no go.
Recently, however, I have noticed that a number of new features in Python 3 have made it possible to use objects in the same style as ADTs, in idiomatic Python with virtually no boilerplate. The key features are dataclasses, Union types, and mypy's ability to narrow types based on isinstance checks.
The key idea: define each constructor as a dataclass, put the constructors together into an ADT using a Union type, and use isinstance tests to do pattern matching on the result. The result is just as good as an ADT (or better, perhaps; their structural nature bears more similarity to OCaml's polymorphic variants).
Here's how it works. Let's suppose that you want to define an algebraic data type with two results:
data Result = OK Int | Failure String

showResult :: Result -> String
showResult (OK result) = show result
showResult (Failure msg) = "Failure: " ++ msg
First, we define each constructor as a dataclass:
from dataclasses import dataclass
from typing import NoReturn, Union
@dataclass(frozen=True)
class OK:
result: int
@dataclass(frozen=True)
class Failure:
msg: str
Using the automatically generated constructors from dataclasses, we can construct values of these dataclasses using OK(2) or Failure("something wrong"). Next, we define a type synonym for the union of these two classes:
Result = Union[OK, Failure]
Finally, we can do pattern matching on Result by doing isinstance tests:
def assert_never(x: NoReturn) -> NoReturn:
raise AssertionError("Unhandled type: {}".format(type(x).__name__))
def showResult(r: Result) -> str:
if isinstance(r, OK):
return str(r.result)
elif isinstance(r, Failure):
return "Failure: " + r.msg
else:
assert_never(r)
assert_never is a well known trick for doing exhaustiveness checking in mypy. If we haven't covered all cases with enough isinstance checks, mypy will complain that assert_never was given a type like UnhandledCtor when it expected NoReturn (which is the uninhabited type in Python).
That's all there is to it. As an extra bonus, this style of writing unions is compatible with the structured pattern matching PEP, if it actually gets accepted. I've been using this pattern to good effect in our recent rewrite of PyTorch's code generator. If you have the opportunity to work in a statically typed Python codebase, give this style of code a try!
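For reference, with the class patterns mentioned in the update at the top (Python 3.10+), the same function can be written with a match statement; here is a quick sketch reusing the OK/Failure dataclasses and assert_never from above:
def showResult(r: Result) -> str:
    match r:
        case OK(result):
            return str(result)
        case Failure(msg):
            return "Failure: " + msg
        case _:
            assert_never(r)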
by Edward Z. Yang at October 14, 2020 06:08 PM

If this is your first time reading about PyTorch internals, you might want to check out my PyTorch internals post first. In this post, I want to talk about one particular part of PyTorch's internals: the dispatcher. At first glance, the dispatcher is just a glorified if statement: based on some information about the tensor inputs, decide what piece of code should be called. So why should we care about the dispatcher?

Well, in PyTorch, a lot of things go into making an operator work. There is the kernel that does the actual work, of course; but then there is support for reverse mode automatic differentiation, e.g., the bits that make loss.backward() work. Oh, and if you run your code under torch.jit.trace, you can get a trace of all the operations that were run. Did I mention that if you run these operations inside of a vmap call, the batching behavior of the operators is different? There are so many different ways to interpret PyTorch operators, and if we tried to handle all of them inside a single function named add, our implementation code would quickly devolve into an unmaintainable mess. The dispatcher is not just an if statement: it is a really important abstraction for how we structure our code internally in PyTorch... and it has to do so without degrading the performance of PyTorch (too much, anyway).

At the end of this post, our goal will be to understand how all the different parts of this picture fit together. This post will proceed in three parts.


First, we'll talk about the dispatcher itself. What is the dispatcher, and how does it decide what kernel to call? Second, we'll talk about the operator registration API, which is the interface by which we register kernels into the dispatcher. Finally, we'll talk about boxing and unboxing, a cross-cutting feature of the dispatcher that lets you write code once and then have it work for all operators.

OK, so what is the dispatcher? For every operator, the dispatcher maintains a table of function pointers which provide implementations for each dispatch key, which corresponds roughly to one of the cross-cutting concerns in PyTorch. In the diagram above, you can see there are dispatch entries in this table for backends (CPU, CUDA, XLA) as well as higher-level concepts like autograd and tracing. The dispatcher's job is to compute a dispatch key, based on the input tensors and some other stuff (more on this shortly), and then do an indirect jump to the function pointed to by the table.
Those of you who are familiar with C++ may observe that this table of function pointers is very similar to virtual tables in C++. In C++, virtual methods on objects are implemented by associating every object with a pointer to a virtual table that contains implementations for each virtual method on the object in question. In PyTorch, we essentially reimplemented virtual tables, but with some differences:
Fun historical note: we used to use virtual methods to implement dynamic dispatch, and reimplemented them when we realized we needed more juice than virtual tables could give us.

So how exactly do we compute the dispatch key which we use to index into the dispatch table? The basic abstraction we use for computing what dispatch key to use is a dispatch key set, which is a bitset over dispatch keys. The general concept is that we union together dispatch key sets from various sources (and in some case mask out some dispatch keys), giving us a final dispatch key set. Then, we pick the first dispatch key in the set (dispatch keys are implicitly ordered by some priority) and that is where we should dispatch to. What are these sources?
There is also a local exclude set, which is used to exclude dispatch keys from dispatch. A common pattern is for some handler to handle a dispatch key, and then mask itself off via the local exclude set, so we don't try reprocessing this dispatch key later.
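To make that computation concrete, here is a toy Python model of the calculation (the real thing lives in C++ inside the dispatcher; the key names, priorities, and table here are illustrative only):
from enum import IntEnum

class DispatchKey(IntEnum):   # higher value = higher priority (illustrative ordering)
    CPU = 1
    CUDA = 2
    BackendSelect = 3
    Autograd = 4

def compute_dispatch_key(tensor_keys, local_include, global_keys, local_exclude):
    # Union the sources, mask out the local exclude set, then take the
    # highest-priority key that remains.
    candidates = (set(tensor_keys) | set(local_include) | set(global_keys)) - set(local_exclude)
    return max(candidates)

# Per-operator dispatch table: dispatch key -> kernel.
add_table = {
    DispatchKey.Autograd: lambda: "autograd handler",
    DispatchKey.CPU: lambda: "cpu kernel",
}

key = compute_dispatch_key({DispatchKey.CPU}, set(), {DispatchKey.Autograd}, set())
print(add_table[key]())   # "autograd handler" wins, since Autograd has higher priority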
Let's walk through the evolution of dispatch key through some examples.

(Warning: This description is out-of-date for PyTorch master. Instead of Autograd being in global, it is instead on the Tensor. Everything else proceeds as before.)
The most canonical example of the dispatch machinery in operation is how it handles autograd. Read the diagram from the top to the bottom. At the very top, Autograd is in the global set, and the local exclude set is empty. When we do dispatch, we find autograd is the highest priority key (it's higher priority than CPU), and we dispatch to the autograd handler for the operator. Inside the autograd handler, we do some autograd stuff, but more importantly, we create the RAII guard AutoNonVariableTypeMode, which adds Autograd to the local exclude set, preventing autograd from being handled for all of the operations inside of this operator. When we redispatch, we now skip the autograd key (as it is excluded) and dispatch to the next dispatch key, CPU in this example. As local TLS is maintained for the rest of the call tree, all other subsequent dispatches also bypass autograd. Finally, in the end, we return from our function, and the RAII guard removes Autograd from the local exclude set so subsequent operator calls once again trigger autograd handlers.
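In Python terms, the RAII guard behaves like a context manager over a thread-local exclude set. Here is a rough sketch of that shape (redispatch and the key names are stand-ins, not PyTorch's API):
import threading
from contextlib import contextmanager

_tls = threading.local()

@contextmanager
def exclude_key(key):
    # Analogue of the RAII guard: add `key` to the thread-local exclude set for
    # the duration of the handler, then restore the previous set on exit.
    previous = getattr(_tls, "excluded", frozenset())
    _tls.excluded = previous | {key}
    try:
        yield
    finally:
        _tls.excluded = previous

def redispatch(op, *args):
    # Stand-in: in the real dispatcher this recomputes the dispatch key with the
    # updated exclude set and jumps to the next handler (e.g., the CPU kernel).
    return "redispatched %s past %s" % (op, sorted(_tls.excluded))

def autograd_handler(op, *args):
    # ... set up the autograd graph here ...
    with exclude_key("Autograd"):
        return redispatch(op, *args)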

Another similar example is tracing, which is similar to autograd where when we enter the tracing handler, we disable tracing for nested calls with ExcludeDispatchKeyGuard. However, it differs from autograd in how tracing is initially triggered: tracing is toggled by a dispatch key that is added to the local include set when you turn on tracing (with IncludeDispatchKeyGuard), as opposed to the global dispatch key from Autograd (Update: now a dispatch key on tensors).

One final example is the BackendSelect key, which operates a little differently from normal keys. The problem backend select solves is that sometimes, the default dispatch key set calculation algorithm doesn't know how to work out what the correct dispatch key should be. One notable case of this is factory functions, which don't have any Tensor arguments (and so, naively, would not dispatch to anything). BackendSelect is in the global dispatch key set, but is only registered for a few operators (for the rest, it is a fallthrough key). The BackendSelect handler inspects the arguments and decides what the final dispatch key should be, and then does a direct dispatch to that key, bypassing dispatch key calculation.

The slide summarizes some of the most common sequences of handlers that get processed when dispatching some operation in PyTorch. Most of the time, it's autograd, and then the backend (with a backend select in-between if you are a factory function). For XLA, there is also an XLAPreAutograd key (Update: This key is now simply AutogradXLA) which can be used to override the behavior of the Autograd key. And of course, if you turn on every feature in PyTorch all at once, you can end up stopping at a lot of handlers. Notice that the order in which these handlers are processed matters, since handlers aren't necessarily commutative.
So we talked a lot about how we decide what function pointers in the dispatch table to call, but how do these pointers get in the dispatch table in the first place? This is via the operator registration API. If you have never seen this API before, you should take a look at the Dispatcher in C++ tutorial, which describes how the API works at a very high level. In this section, we'll dive into more detail about how exactly the registration API maps to the dispatch table. Below, you can see the three main ways of interacting with the operator registration API: you define schemas for operators and then register implementations at dispatch keys; finally, there is a fallback method which you can use to define a handler for all operators at some dispatch key.

To visualize the impact of these registration operators, let us imagine that the dispatch tables for all operators collectively form a grid, like this:

On one axis, we have each operator supported in PyTorch. On the other axis, we have each dispatch key we support in our system. The act of operator registration involves filling in cells with implementations under these two axes.
When we register a kernel for a single operator at a specific dispatch key, we fill in a single cell (blue below):

When you register a kernel as a "catch-all" kernel for all dispatch keys in an operator, you fill in an entire row for the operator with one kernel (red below). By the way, if this seems like a strange thing to want to do, it is! And we're working to remove this capability in favor of more specific fills for a subset of keys.

When you register a kernel as a fallback kernel for a single dispatch key, you fill in the column for that dispatch key (green).

There's a precedence to these registrations: exact kernel registrations have the highest precedence, and catch-all kernels take precedence over fallbacks.
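A toy lookup that captures this precedence (illustrative Python, not the real registration machinery):
def lookup_kernel(op, key, exact, catch_all, fallback):
    # exact: {(op, key): kernel}, catch_all: {op: kernel}, fallback: {key: kernel}
    if (op, key) in exact:      # a kernel registered for this operator at this key
        return exact[(op, key)]
    if op in catch_all:         # the operator's catch-all kernel, if any
        return catch_all[op]
    return fallback[key]        # otherwise the backend fallback for this key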

I want to spend the last part of this post talking about the boxing and unboxing facilities in our dispatcher, which turn out to be pretty important for enabling backend fallback. When you are a programming language designer, there is a classic tradeoff you have to make in deciding whether or not you want to use a boxed or unboxed representation for data:

A boxed or homogenous representation is a data representation where every type of object in your system has the same layout. Typically, this means you have some representation that has a header describing what the object in question is, and then some regular payload after it. Homogenous representations are easy to work with in code: because you can always assume that data has some regular layout, you can write functions that work polymorphically over any type of data (think of a function in Java that takes in an arbitrary Object, for example). Most garbage-collected languages have some boxed representation for heap objects, because the garbage collector needs to be able to work over any type of heap object.
In contrast, an unboxed or heterogenous representation allows objects to have a different layout depending on the data in question. This is more efficient than a homogenous representation, as each object can tailor its internal representation to exactly what is needed for the task at hand. However, the downside is we can no longer easily write a single function that works polymorphically over many types of objects. In C++, this problem is worked around using templates: if you need a function to work on multiple types, the C++ compiler will literally create a new copy of the function specialized to each type it is used with.

By default, C++ uses heterogenous layout, but we have implemented homogenous layout in PyTorch by way of the IValue struct (short for interpreter value), which implements a boxed representation that we can use in our interpreter. An IValue is a two-word structure consisting of a payload word (usually a pointer, but it could also be an integer or float directly packed into the field) and a tag word which tells us what kind of value the IValue is.
This means we have two calling conventions for functions in PyTorch: the usual, C++, unboxed convention, and a boxed convention using IValues on a stack. Calls (from end users) can come from unboxed API (direct C++ call) or boxed API (from the JIT interpreter); similarly, kernels can be implemented as direct C++ functions (unboxed convention), or can be implemented as a boxed fallback (which by necessity is boxed, as they are polymorphic over all operators).
If I call from boxed API to a boxed fallback, it's easy to see how to plug the two components together...

...but how do I get from the unboxed API to the boxed fallback?

We need some sort of adapter to take the unboxed inputs and turn them into IValues so that they can be passed via the boxed calling convention. This is done via a boxing adapter, which is automatically generated using C++ templates working off of the unboxed C++ types in the outward facing API.

There is also an inverse problem, which is what to do if we have inputs from a boxed API and need to call into an unboxed kernel. Similarly, we have an unboxing adapter, which performs this translation. Unlike the boxing adapter, this adapter is applied to the kernel itself, since C++ templates only work at sites where the unboxed type is statically available (at the boxed API site, these types are not known, so you literally cannot implement this). Note that we always keep the unboxed API around, so that if a user calls in from the unboxed API, we can fastpath straight to the unboxed kernel.
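To make the two conventions and adapters concrete, here is a toy Python model in which a "boxed" value is just a (tag, payload) pair and the boxed calling convention passes a list of them as a stack; this is an illustration of the idea, not PyTorch's IValue or its template machinery:
def box(x):
    return (type(x).__name__, x)      # morally: a tag word plus a payload word

def unbox(iv):
    return iv[1]

def boxing_adapter(boxed_kernel):
    # Unboxed call site -> boxed kernel: box the arguments onto a stack,
    # call the kernel, unbox the result (cf. the boxing adapter above).
    def unboxed_entry(*args):
        stack = [box(a) for a in args]
        boxed_kernel(stack)
        return unbox(stack.pop())
    return unboxed_entry

def unboxing_adapter(unboxed_kernel):
    # Boxed call site -> unboxed kernel: unbox the stack into real arguments,
    # call the kernel, push the boxed result back (cf. the unboxing adapter).
    def boxed_entry(stack):
        args = [unbox(iv) for iv in stack]
        stack.clear()
        stack.append(box(unboxed_kernel(*args)))
    return boxed_entry

# Example: an unboxed kernel exposed through the boxed convention and back again.
boxed_add = unboxing_adapter(lambda x, y: x + y)
print(boxing_adapter(boxed_add)(1, 2))   # 3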

So here is what boxing and unboxing look like overall:

Boxing and unboxing are a key feature in the implementation of boxed fallback: without them, we could not let people write single kernels which would work everywhere (and indeed, in the past, people would write code generators to generate repetitive kernels for every function). With template-based boxing and unboxing, you can write a single boxed kernel and have it work for all operators, even if those operators are defined externally from the library.

So that's PyTorch's dispatcher in a nutshell! The dispatcher is still being continuously worked on; for example, Ailing Zhang recently landed a rework of how autograd dispatch keys are handled, which means that we actually no longer have a single Autograd key but have split autograd keys for AutogradCPU/AutogradCUDA/... We're generally interested in improving the user experience for people who register kernels to the dispatcher. Let us know if you have any questions or comments!
by Edward Z. Yang at September 10, 2020 06:29 PM
For the longest time, I thought implicit parameters and dynamic scoping were basically the same thing, since they both can be used to solve similar problems (e.g., the so-called "configuration problem", where you need to plumb some configuration deep into a nested body of function definitions without threading it through all of them explicitly). But implicit parameters have a reputation of being something you shouldn't use (use reflection instead), whereas dynamic scoping via the reader monad is a useful and well understood construct (except for the bit where you have to monadify everything). Why the difference?
Oleg points out that implicit parameters are not really dynamic scoping, and gives an example where Lisp and Haskell disagree. And you don't even want the Lisp behavior in Haskell: if you think about the operational notion of dynamic scoping (walk up the stack until you find a binding site of the dynamic variable), it's not very compatible with laziness, since a thunk (which accesses a dynamic variable) will be forced at some unpredictable point in program execution. You really don't want to have to reason about where exactly a thunk will be executed to know how its dynamic variables will be bound, that way lies madness. But somehow, in a strict language, no one has trouble figuring out what should happen with dynamic scoping (well, mostly--more on this shortly).
It turns out that the research community has figured out the difference: implicit parameters are a coeffect. I believe this was first observed in Coeffects: Unified static analysis of context-dependence (a more modern presentation is in Coeffects: A calculus of context-dependent computation; a more Haskelly presentation can be found in Embedding effect systems in Haskell). That said, Tomas was commenting on my blog in 2012 about similar ideas, so this probably had been in the works for a while. The key point is that for some coeffects (namely, implicit parameters), call-by-name reduction preserves types and coeffects, and so implicit parameters do not blow up in your face in the same way dynamic scoping (an effect) would. These necessarily behave differently! Type classes are coeffects too, and this is why modern use of implicit parameters in Haskell explicitly acknowledges this (e.g., in the reflection package).
At this year's ICFP, I was pointed at an interesting technical report about implicit values and functions in Koka, a new twist on dynamic scoping. I found myself wondering if Haskell implicit parameters could learn a thing or two from this work. Implicit values make the good choice of defining implicit values globally at the top level, so that they can participate in normal module namespacing, as opposed to an un-namespaced bag of dynamically scoped names (this is also an improvement that reflection makes over implicit parameters). But actually, it seems to me that implicit functions are taking a page from implicit parameters!
The big innovation of implicit functions is that they resolve all dynamic references in the function (not just lexically, but for all further dynamic calls) to the lexical scope (the dynamic scope at the time the function was defined), producing a function that has no dependence on implicit values (aka, no effect saying that the implicit value must be defined at the time the function is called). This is exactly what an implicit parameter let ?x = ... binding would have done, in effect directly filling in the dictionary for the implicit function at the definition site, rather than waiting. Very contextual! (Of course, Koka implements this using algebraic effects, and gets to the right semantics with a very simple translation anyway.) The result is not exactly dynamic scoping, but as the TR says, it leads to better abstraction.
It is difficult to see how implicit values/functions could make their way back into Haskell, at least without some sequencing construct (e.g., a monad) lurking around. Though implicit functions behave much like implicit parameters, the rest of the dynamic scoping (including the binding of the implicit function itself) is just good old effectful (not coeffectful) dynamic scope. And you can't just do that in Haskell without breaking type preservation under beta-reduction and eta-expansion. Haskell has no choice but to go all the way, and once you get beyond the obvious problems of implicit parameters (which reflection fixes), things seem to mostly work out.
by Edward Z. Yang at August 27, 2020 05:51 AM
Summary: Read about my efforts to solve the game of Ultimate Tic Tac Toe. It’s been a fun journey into interesting algorithms and high-performance parallel programming in Rust. Backstory Starting around the beginning of the COVID-19 lockdown, I’ve gotten myself deeply nerdsniped by an attempt to solve the game of Ultimate Tic Tac Toe, a two-level Tic Tac Toe variant which is (unlike Tic Tac Toe) nontrivial and contains some interesting strategic elements.
I've recently been working on a revamp of how we specify tensor shape formulas in PyTorch. As part of this process, I classified every single operator in PyTorch by its shaping behavior; yes, that's all 1364 of them (this includes each variant of an operator; e.g., inplace and out= keyword variants). During the process, I tried to come up with categories to help classify what operators did. One of the surprises from the process was discovering that shaping behaviors that I previously thought were uncommon, actually showed up a bit more often than one might have expected.
These categories are interesting in their own right and can be used to help understand how PyTorch's API fits together. Here are all the categories I devised.
TensorIterator (505, e.g., add, sum) operators are PyTorch's bread and butter; these operators do pointwise operations and reductions and support broadcasting and type promotion (see the short example after the category list below). The name TensorIterator refers to an internal abstraction we have in PyTorch for implementing these operations; you can read more about it on the wiki and in this blog post. TensorIterator is a real workhorse in PyTorch: the plurality (though not majority) of operators are implemented this way! Note that this category includes some functions that used equivalent, legacy functionality (but did not exactly use TensorIterator).
Fixed (273, e.g., convolution, addbmm) operators are operators which only work on a fixed number of dimensions. This assumption makes writing efficient kernels a lot easier, as indexing math is simple with fixed dimensionality. (For example, TensorAccessor is an internal class which lets you view a tensor at fixed dimensionality known at compile time). Sometimes, the first dimension is treated as a batch dimension, but not always (unfortunately, I didn't distinguish these cases in my dataset). Some fixed operators actually support multiple dimensions, but only a fixed number of them; for example, because we only support 1-3D convolutions, this counts as fixed. (Compare this with FeatureBatched, below!)
N-Dimensional (107, e.g., squeeze, index_add, tensordot) operators are operators which work generically on tensors of arbitrary dimensionality. These are the operations for which it is difficult to write generic shaping rules in symbolic form, as you need a language that can talk about list manipulations. An important subclass of N-dimensional operators are Identity (42, e.g., clone, contiguous; not included in the count above) operators, which work over arbitrary dimensionality but always return a tensor with the same size as their input. Another subclass are Flatten (11, e.g. take, bucketize) operators, which accept tensors of any dimensionality but always treat them as 1D tensors internally.
Composite (95, e.g., kl_div, isfinite) operators are implemented in other operators, and don't themselves have shape checking (instead, they rely on the operations they call to check shapes). Note this category is probably a bit underreported, as in some cases when it was obvious what the underlying behavior of an operator was, I classified the operator as that category, rather than Composite.
Batched (94, e.g., nll_loss, adaptive_avg_pool2d) operators are like fixed dimensionality operators, except they accept an arbitrary number of batch dimensions at their beginning. Many fixed operators should be batched operators; others cannot be converted into batched operators without introducing ambiguity as to where the batch dimensions end. Compare these with FeatureBatched (19, e.g., batch_norm, embedding) operators, which are like batched operators, but rather than accept batch dimensions at the beginning, they accept an arbitrary number of feature dimensions at the end.
Factory (90, e.g., empty) operators produce new tensors without having any tensor inputs.
Trivial (59, e.g., size, is_floating_point) operators aren't actual tensor operations, but ways to return non-Tensor information or access internal data structures.
Sparse (40) operators are special because their size calculations take account of both dense and sparse dimensions.
Dynamic (15, e.g., unique) operators produce outputs whose shapes depend on the data of their input tensors.
Variadic (14, e.g., cat) operators take multiple input tensors; similar to n-dimensional operations, they are difficult to capture symbolically.
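As a concrete illustration of the TensorIterator category above, here is what broadcasting and type promotion look like from the Python API (a small self-contained example, not part of the dataset):
import torch

# Broadcasting: a (3, 1) tensor and a (1, 4) tensor combine into a (3, 4) result.
a = torch.randn(3, 1)
b = torch.randn(1, 4)
print((a + b).shape)           # torch.Size([3, 4])

# Type promotion: adding an integer tensor to a float tensor promotes to float.
i = torch.tensor([1, 2, 3])    # dtype=torch.int64
f = torch.tensor([0.5])        # dtype=torch.float32
print((i + f).dtype)           # torch.float32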
You can take a look at the full data set at https://docs.google.com/spreadsheets/d/e/2PACX-1vQQFW0T_bucT5KZn0BHYTC1KYhkL6ZMG5ZxQWc6UmAkHUDYpqkpzXnsb59uv2TB0Jgc1Q6qO63bx6WQ/pubhtml
by Edward Z. Yang at May 06, 2020 03:56 PM
Alex Gaynor recently asked this question in an IRC channel I hang out in (a channel which contains several software engineers nearly as obsessed with software testing as I am): uhh, so I’m writing some code to handle an econnreset… how do I test this? This is a good question! Testing ECONNRESET is one of those fiddly problems that exists at the interface between systems — in his case, with S3, not even a system under his control — that can be infuriatingly tricky to reproduce and test.
Suppose we have some codebase we’re considering applying some patch to, and which has a robust and maintained test suite. Considering the patch, we may ask: is this patch acceptable to apply and deploy? By this we mean to ask if the patch breaks any important functionality, violates any key properties or invariants of the codebase, or would otherwise cause some unacceptable risk or harm. In principle, we can divide all patches into “acceptable” or “unacceptable” relative to some project-specific notion of what we’re willing to allow.
Last week, I wrote about the mindset that computer systems can be understood, and behaviors can be explained, if we’re willing to dig deep enough into the stack of abstractions our software is built atop. Some of the ensuing discussion on Twitter and elsewhere led me to write this followup, in which I want to run through a few classes of systems where I’ve found pursuing in-detail understanding of the system wasn’t the right answer.
Introduction This post attempts to describe a mindset I’ve come to realize I bring to essentially all of my work with software. I attempt to articulate this mindset, some of its implications and strengths, and some of the ways in which it’s led me astray. Software can be understood I approach software with a deep-seated belief that computers and software systems can be understood. This belief is, for me, not some abstruse theoretical assertion, but a deeply felt belief that essentially any question I might care to ask (about computers) has a comprehensible answer which is accessible with determined exploration and learning.
At this point in my career, I’ve worked on at least three projects where performance was a defining characteristic: Livegrep, Taktician, and Sorbet (I discussed sorbet in particular last time, and livegrep in an earlier post). I’ve also done a lot of other performance work on the tools I use, some of which ended up on my other blog, Accidentally Quadratic. In this post, I want to reflect on some of the lessons I’ve learned while writing performant software, and working with rather a lot more not-so-performant software.
vmap is an interface popularized by JAX which offers you a vectorizing map. Semantically, a vmap is exactly equivalent to a map in Haskell; the key difference is that operations run under a vmap are vectorized. If you map a convolution and a matrix multiply, you will have one big loop which repeatedly calls convolution and matrix multiply for each entry in your batch. If you vmap a convolution and matrix multiply, you'll call the batched versions of convolution and matrix multiply once. Unless you have a fuser, on most modern deep learning frameworks, calling the batched implementations of these operations will be much faster.
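To see the semantic equivalence concretely, here is a small JAX example (my own illustration, not from the post): the vmapped function produces the same values as an explicit per-example loop, but issues batched operations under the hood:
import jax
import jax.numpy as jnp

def f(x):
    return jnp.dot(x, x) + 1.0             # some per-example computation

xs = jnp.arange(12.0).reshape(4, 3)        # a batch of 4 examples

looped = jnp.stack([f(x) for x in xs])     # map: one call per example
batched = jax.vmap(f)(xs)                  # vmap: one batched call over the leading axis
print(jnp.allclose(looped, batched))       # True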
JAX implements vmap in a somewhat complicated fashion; they have a "batched interpreter" which translates operations on primitives into their batched versions, and have to track metadata about what tensors are batched and in what way so that they can insert appropriate broadcasts and unsqueezes. I mentioned this to Simon Peyton Jones, and he immediately asked, couldn't Haskell's typechecker work this out automatically? The answer is, yes! All of the book-keeping JAX has to do is effectively doing runtime type inference; if you have a compiler that can do it for you at compile time, there is nearly nothing to implement.
To give away the punchline, we are going to implement a family of functions vmap that will run these two examples:
example1 :: [Float] -> [Float] -> [Float]
example1 a0 b0 = vmap0_2 (\a b -> add a b) a0 b0

example2 :: [Float] -> [Float] -> [[Float]]
example2 a0 b0 = vmap0 (\a -> vmap1 (\b -> add a b) b0) a0
When run in an interpreter, we will see:
*Test> example1 [1,2,3] [4,6,8]
[5.0,8.0,11.0]
*Test> example2 [1,2,3] [4,6,8]
[[5.0,7.0,9.0],[6.0,8.0,10.0],[7.0,9.0,11.0]]
These results are equivalent to what you would have gotten using a plain old map; however, there will be no loop in the implementation of vmap. (The fact that we can't write a single vmap that works universally is due to a limitation in Haskell; we'll discuss this more later.)
We're going to need a few language extensions, so let's get this out of the way first:
{-# LANGUAGE RankNTypes, GADTs, MultiParamTypeClasses,
KindSignatures, TypeApplications, FunctionalDependencies,
FlexibleContexts, FlexibleInstances, UndecidableInstances,
IncoherentInstances #-}
Our plan of attack is that we want to write the definitions of vmap so that we infer a type for add which makes the necessary broadcasting clear. A trivial implementation of vmap would have the signature ([a] -> [b]) -> [a] -> [b] (aka the identity function), but the standard list type doesn't let us distinguish between dimensions we should broadcast together, and dimensions we shouldn't (this is the reason example1 and example2 give different results: in example2, we broadcast along each dimension separately, so that we end up with a cartesian product in the end; in example1, we broadcast the dimensions together and get the zippy behavior). Each distinct invocation of vmap should give us a new dimension, which ought not to be mixed up with other invocations of vmap. When you hear this in Haskell, your first instinct should be, "I know, let's use a rank 2 type!" vmap moves us from the non-type-branded world of vanilla lists [Float] to a type-branded world of size-indexed vectors Vec s Float, where the s variables are all skolem variables bound by our rank 2 type:
data Vec s a = Vec { unVec :: [a] }
instance Functor (Vec s) where
fmap f (Vec xs) = Vec (map f xs)
vmap0 :: (forall s. Vec s a -> Vec s b) -> [a] -> [b]
vmap0 f = unVec . f . Vec
The implementation of vmap0 doesn't do anything: we just wrap the lists into their type-branded equivalent vectors. We can also provide a 2-ary version of vmap0, which takes two lists and assigns them the same type branding all at once:
vmap0_2 :: (forall s. Vec s a -> Vec s b -> Vec s c) -> [a] -> [b] -> [c]
vmap0_2 f a b = unVec (f (Vec a) (Vec b))
(In principle, some sort of applicative-y thing should make it possible to write just a vap (analogous to ap) and then get all of the n-ary versions for free, but in my brief investigation I didn't see a good way of doing this.)
When we nest vmap, it may be the case that the function doesn't directly return a Vec s b, but a functor containing Vec s b. vmap1 handles this case (we'll discuss this more shortly):
vmap1 :: Functor f => (forall s. Vec s a -> f (Vec s b)) -> [a] -> f [b]
vmap1 f = fmap unVec . f . Vec
With our implementations of vmap in hand, we can take a look at our examples and ask Haskell what the type of add ought to be, if we didn't have an implementation of it:
example1 :: [Float] -> [Float] -> [Float]
example1 a0 b0 = vmap0_2 (\a b -> _add a b) a0 b0
Gives:
• Found hole: _add :: Vec s Float -> Vec s Float -> Vec s Float
Where: ‘s’ is a rigid type variable bound by
a type expected by the context:
forall s. Vec s Float -> Vec s Float -> Vec s Float
However:
example2 :: [Float] -> [Float] -> [[Float]]
example2 a0 b0 = vmap0 (\a -> vmap1 (\b -> _add a b) b0) a0
Gives:
• Found hole:
_add :: Vec s Float -> Vec s1 Float -> Vec s (Vec s1 Float)
Where: ‘s1’ is a rigid type variable bound by
a type expected by the context:
forall s1. Vec s1 Float -> Vec s (Vec s1 Float)
at test.hs:41:20-44
‘s’ is a rigid type variable bound by
a type expected by the context:
forall s. Vec s Float -> Vec s [Float]
at test.hs:41:7-48
Notice that the inferred types of _add are different in these two cases: in the first example, we infer that we have two tensors batched in the same way, and we want to "zip" them together. In the second example, we see that each tensor has a distinct batch dimension, and we end up with a 2-D result!
At this point, the job of vmap is done: our holes have types which we can use to determine what the necessary behavior is. You could use these types to select an appropriate kernel to perform vectorized addition. But I promised runnable code, so let's implement a simple version of add using old fashioned map.
The good old fashioned way to do type level computation in Haskell is with a type class, of course! Let's define a multi-parameter type class for the function add; unlike the definition of (+) in Num, we'll let the inputs and output all have different types:
class Add a b c | a b -> c where add :: a -> b -> c
We can easily implement addition on plain floating point:
instance Add Float Float Float where add = (+)
If I pass add two arguments whose outer-most vector agree in their type brand (aka, they came from the same vmap), I should zip them together, as I did in example1. I can write another instance to express this logic:
instance Add a b r => Add (Vec s a) (Vec s b) (Vec s r) where add (Vec a) (Vec b) = Vec (zipWith add a b)
Otherwise, I should broadcast one of the dimensions and then do an addition on the inside. This choice can't easily be made locally, so I have to define these two incoherent instances:
instance Add a b r => Add (Vec s a) b (Vec s r) where
  add (Vec a) b = Vec (map (\x -> add x b) a)

instance Add a b r => Add a (Vec s b) (Vec s r) where
  add a (Vec b) = Vec (map (\x -> add a x) b)
(GHC's type class resolution engine doesn't backtrack, so I'm not actually sure how it manages to pick the correct instance to use, but in my testing, I got the right instance no matter what order I specified the arguments to add.)
That's it! Running the two examples:
example1 :: [Float] -> [Float] -> [Float]
example1 a0 b0 = vmap0_2 (\a b -> add a b) a0 b0

example2 :: [Float] -> [Float] -> [[Float]]
example2 a0 b0 = vmap0 (\a -> vmap1 (\b -> add a b) b0) a0
I get:
*Test> example1 [1,2,3] [4,6,8]
[5.0,8.0,11.0]
*Test> example2 [1,2,3] [4,6,8]
[[5.0,7.0,9.0],[6.0,8.0,10.0],[7.0,9.0,11.0]]
So there you have it! vmap in less than a dozen lines of Haskell. One unsatisfactory thing about this implementation is the necessity to define vmap0, vmap1, etc. Can't we just define a generic vmapG :: (forall s. Vec s a -> f (Vec s b)) -> [a] -> f [b] and have f unify with, well, the identity type lambda /\a. a when we need it to have the type of vmap0? Regretfully, type inference with type lambdas is undecidable (the so-called higher-order unification problem), so it seems we have to help GHC out here, even though in our particular case the unification we can do here is very restricted.
by Edward Z. Yang at January 29, 2020 07:14 PM
This is the second in an indefinite series of posts about things that I think went well in the Sorbet project. The previous one covered our testing approach. Sorbet is fast. Many of our early users commented specifically on how fast it was, and how much they appreciated this speed. Our informal benchmarks on Stripe’s codebase clocked it as typechecking around 100,000 lines of code per second per core, making it one of the fastest production typecheckers we are aware of.
Testing and feedback loops This post tries to set out one mental model I have for thinking about testing and the purpose testing serves in software engineering, and to explore some of the suggestions of this model. As mentioned in an earlier post, I think a lot about working in long-lived software projects that are undergoing a lot of development and change. The goal when working on these projects is not just to produce a useful artifact at one time, but to maintain and evolve the project over time, optimizing for some combination of the present usefulness of the software, and our ability to continue to evolve and improve it into the future.
This is a list of reasons why I think Python is a terrible programming language. Naturally, none of them apply to other programming languages.
It's impossible to determine at compile time whether a Python program can terminate.
While strings in Python 3 are significantly better than strings in Python 2, they consist of a series of "code points" instead of characters, so there's no way to reference the third character of a Python string. For ASCII strings, this works fine in C, so I prefer writing in C. Python claims to support Unicode but can't even get this right.
If I compile a Python module on Linux, it doesn't work on a Mac. This is a shocking oversight for a so-called "cross-platform" language.
The os.link function does not let you create a hard link to a directory.
The standard library sorting function takes at least O(n log n) time. I understand dynamic languages are slow, but why is it this slow?
The cryptography module claims to implement public-key cryptography, but mathematicians have so far been unable to prove that one-way functions, a prerequisite of public-key cryptography, even exist. It's surprising that the Python ecosystem is of such low quality and holds to dubious scientific standards.
When you compile NumPy or SciPy from source, you need to build FORTRAN code, and FORTRAN sucks, which is clearly Python's fault.
If you're running two Python programs from different users on the same machine, a speculative load in one program might be able to learn information from the other program based on cache timing side channels.
Python is unable to reverse entropy.
by Geoffrey Thomas at December 18, 2018 12:00 AM
The git bisect command helps you identify the first commit in a range that broke something. You give it a good commit and a bad one, and it will do a binary search between the two to find the first bad commit. At each step, you say either git bisect good or git bisect bad depending on whether it passes your test, and it will move you halfway through the remaining commits in the range.
There are several guides for using git bisect with the Linux kernel (e.g., upstream, Gentoo, and Ubuntu all have one). Unfortunately, they're pretty time-intensive operations; they all say something to the effect of, "now build the kernel, reboot into it, and test it, then type git bisect good or git bisect bad depending on whether it worked." For a tricky hardware compatibility bug, this might be your only option. But if you're testing something about the kernel's behavior, this is unnecessarily slow and manual, and you might be tempted to do something else, like read commit logs.
At work a few days ago, someone reported that a certain application no longer worked in a new VM. After some initial debugging with strace, we determined that the program was calling the listen system call with a backlog of 0: that is, it was saying it was willing to accept up to zero connections. By the specification, it shouldn't work—and yet it did work on their older VM. A few things were different between the new systems, but one notable one was that the new VM had kernel 4.9 and the old one kernel 4.1. (Another was that it was deployed in a new cloud environment that my team is responsible for, with some networking changes, so we wanted to ensure we had not broken anything!)
I tried reading through git log --grep listen v4.1..v4.9 net/, but there was entirely too much and I couldn't find anything. So I decided to see if bisection could help me, with the use of git bisect run, which enables fully automated bisecting. I wasn't excited about rebooting my machine to do a binary search across eight kernel releases, but if I could get it to run in some other way, I could just leave it running.
For a normal program, it's pretty easy to use git bisect run, which just wants a command that returns success (0) or failure (1): you can usually do something like git bisect run make test. For a kernel regression, though, we'll need a command to boot the kernel and run some code. We can use the qemu virtual machine software for this, which has two properties that make it particularly suitable as such a command: it can boot a Linux kernel directly, instead of emulating a bootloader on a hard disk, and it can run a temporary VM in a single command line without any additional setup.
We'll build ourselves a tiny "initrd" (initial RAM disk), which is what's commonly used to load enough drivers to access your hard drive and completely boot your system. However, our initrd will just contain our one test program, which will possibly print a success message, and shut down the system. We can't meaningfully get a return value out of qemu, so we'll just grep its output for the success message.
The first step is to check out the kernel sources, if we don't have them already, and build a kernel:
$ git clone https://git.kernel.org/pub/scm/linux/kernel/git/stable/linux-stable.git linux
$ cd linux
$ make defconfig
$ make -j8
which prints this after a while:
Kernel: arch/x86/boot/bzImage is ready (#21)
Then make sure we can boot our new kernel with qemu, without an initrd:
$ qemu-system-x86_64 -nographic -append console=ttyS0 -kernel arch/x86/boot/bzImage
That is, run it in text mode with the VM's serial console on standard input/output instead of trying to pop up a graphical window, and tell the kernel to use the serial port for console output. If your system supports it, you can add -enable-kvm to make it faster, although since we want to shut down the VM immediately once we run our test, it doesn't make a huge difference (2 seconds vs. 4 on my machine).
This will panic, because we gave the kernel neither a root filesystem nor a working initrd. (You can kill the VM by typing Ctrl-A and then X.) So let's write an initrd with a single binary, init. It needs to shut down the system, so we get back to our prompt:
$ mkdir initrd
$ cd initrd
$ cat > init.c << EOF
#include <sys/reboot.h>
#include <stdio.h>
#include <unistd.h>
int main(void) {
printf("Hello world!\n");
reboot(RB_POWER_OFF);
}
EOF
(Yes, the system call for shutting down the system is named "reboot", because the name "shutdown" was already used for the system call to close a socket. I guess early UNIX computers didn't support initiating a hardware poweroff from software, so the shutdown command would just stop all processes, sync and unmount disks, and print a message asking the operator to cut power.)
Compile this program statically, so it's a single binary, put it in the particular form required for an initrd (a compressed cpio archive, an old but very simple format with a weird command-line tool) and make sure it's named init, and then we can boot it up with qemu:
$ cc -static -o init init.c
$ echo init | cpio -H newc -o | gzip > initrd.gz
1621 blocks
$ cd ..
$ qemu-system-x86_64 -nographic -append console=ttyS0 -kernel arch/x86/boot/bzImage -initrd initrd/initrd.gz
...
[ 0.502593] ALSA device list:
[ 0.502889] No soundcards found.
[ 0.503554] Freeing unused kernel memory: 1088K (ffffffff81f2f000 - ffffffff8203f000)
[ 0.504262] Write protecting the kernel read-only data: 14336k
[ 0.505004] Freeing unused kernel memory: 1680K (ffff88000185c000 - ffff880001a00000)
[ 0.505855] Freeing unused kernel memory: 1340K (ffff880001cb1000 - ffff880001e00000)
Hello world!
[ 1.089618] input: ImExPS/2 Generic Explorer Mouse as /devices/platform/i8042/serio1/input/input3
[ 1.092997] ACPI: Preparing to enter system sleep state S5
[ 1.094083] reboot: Power down
Great. We've built our own kernel, passed it a test binary to run, and got it booted in a qemu command that exits. This is turning into something we can pass to git bisect run. Now it's time to write the actual test. Here's what I ultimately ended up with to track down my bug:
#include <sys/types.h>
#include <sys/stat.h>
#include <sys/mount.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <sys/reboot.h>
#include <sys/ioctl.h>
#include <net/if.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <fcntl.h>
#include <stdio.h>
#include <unistd.h>

int main(void) {
    /* The problem I was tracing was only reproducible with syncookies
       disabled. While the initrd gets unpacked into a writable temporary
       filesystem, nothing exists yet, so if I need /proc, I need to create
       and mount it myself. */
    if (getpid() == 1) {
        mkdir("/proc", 0755);
        mount("proc", "/proc", "proc", 0, NULL);
        char buf[] = "0\n";
        int fd = open("/proc/sys/net/ipv4/tcp_syncookies", O_WRONLY);
        write(fd, buf, 2);
        close(fd);
    }
    int server = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
    /* Also, while a loopback ethernet device exists, it's not
       enabled, so network tests won't work. This code is equivalent
       to `ifconfig lo up`. */
    struct ifreq ifreq = {
        .ifr_name = "lo",
    };
    ioctl(server, SIOCGIFFLAGS, &ifreq);
    if (!(ifreq.ifr_flags & IFF_UP)) {
        ifreq.ifr_flags |= IFF_UP;
        ioctl(server, SIOCSIFFLAGS, &ifreq);
    }
    struct sockaddr_in addr = {
        .sin_family = AF_INET,
        .sin_port = htons(54321),
        .sin_addr = {htonl(INADDR_LOOPBACK)},
    };
    bind(server, (struct sockaddr *)&addr, sizeof(addr));
    listen(server, 0);
    int client = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
    struct timeval timeout = {3, 0};
    setsockopt(client, SOL_SOCKET, SO_SNDTIMEO, &timeout, sizeof(timeout));
    if (connect(client, (struct sockaddr *)&addr, sizeof(addr)) == 0) {
        printf("Success\n");
    } else {
        perror("connect");
    }
    if (getpid() == 1) {
        reboot(RB_POWER_OFF);
    }
    return 0;
}
Most of it is specific to the thing I was trying to test, but you may also need the code to create and mount /proc or to enable lo. Also, I put a few things conditional on getpid() == 1 so that I could safely test the program on my host system, where it wasn't running as root and where I didn't want it powering anything off. (I ran it a few times under strace to make sure it was doing what I expected it to do, and I didn't want to bother with getting strace inside my initrd.)
So I first made sure this is reproducible on a stock kernel by itself, isolated from any config my workplace might add:
$ qemu-system-x86_64 -nographic -append console=ttyS0 -kernel arch/x86/boot/bzImage -initrd initrd/initrd.gz | grep ^Success
$ git checkout v4.1
$ make defconfig && make -j8
$ qemu-system-x86_64 -nographic -append console=ttyS0 -kernel arch/x86/boot/bzImage -initrd initrd/initrd.gz | grep ^Success
Success
Cool, it's definitely a regression somewhere between those versions. (The set of config options changes from kernel version to kernel version, so across this wide a range, the easiest thing is to just get the current kernel's default config - if you need custom config options, you might want to edit .config after running make defconfig or something.) Now time to let git bisect run do its thing:
$ git bisect start
$ git bisect bad v4.9
$ git bisect good v4.1
$ git bisect run sh -c 'make defconfig && make -j8 && qemu-system-x86_64 -nographic -append console=ttyS0 -kernel arch/x86/boot/bzImage -initrd initrd/initrd.gz | grep ^Success'
It started printing a bunch of build logs and I went to work on something else. About half an hour later (I expected it to take longer!), it printed this out:
ef547f2ac16bd9d77a780a0e7c70857e69e8f23f is the first bad commit
commit ef547f2ac16bd9d77a780a0e7c70857e69e8f23f
Author: Eric Dumazet <edumazet@google.com>
Date: Fri Oct 2 11:43:37 2015 -0700
tcp: remove max_qlen_log
This control variable was set at first listen(fd, backlog)
call, but not updated if application tried to increase or decrease
backlog. It made sense at the time listener had a non resizeable
hash table.
Also rounding to powers of two was not very friendly.
Signed-off-by: Eric Dumazet <edumazet@google.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
$ git describe --contains ef547f2ac16bd9d77a780a0e7c70857e69e8f23f
v4.4-rc1~141^2~238^2~2
which looks awfully relevant—it implies they were previously rounding off the backlog. Looking at the commit, we can see what happened: before kernel 4.4, the backlog argument was always capped to at least 8, and also rounded up to the next power of two. So listen(fd, 0) was turning into listen(fd, 8) on older kernels, and the program previously worked despite using listen() incorrectly. This commit was actually somewhere in the git log I was trying to read, but I must have scrolled past it.
git reflog shows that git bisect went through sixteen commits before settling on this one: it found the bad commit on its 11th try, and then spent 5 more steps confirming that all the commits before it were good. So I'm glad git bisect run found this commit, and I'm especially glad it found it unattended in half an hour, without me having to compile and test sixteen kernels by hand.
by Geoffrey Thomas at March 02, 2018 12:00 AM
When I first arrived at MIT in 2010, I was awed by how much computing infrastructure was readily available. Dorm rooms had 100 megabit internet connections, the at-the-time latest 802.11n Wi-Fi blanketed campus, and setting up a website took a few clicks.
[Image: Map of the Internet from xkcd (CC BY-NC 2.5). MIT sold off parts of 18/8 in April 2017.]
[Image: Amazon now owns 18.145/16.]
> The Library has the entire Net 18 address space registered at many hundreds of publishers of licensed e-resources. With no prior notice, we have been forced into non-compliance with our licenses with every such provider.

Although the ranges sold initially were unused, IS&T announced that the entire upper half of MITnet would be sold, and that buildings would need to be renumbered. Renumbering is always difficult and time-consuming, but a task that can be accomplished, given sufficient notice.
[Image: Slide from an IS&T presentation at the IT Partners conference, announcing the campus-wide NAT.]
by Alex Chernyakhovsky (noreply@blogger.com) at June 25, 2017 03:02 PM
I've always heard stories about nasty Android malware that throws full-screen advertisements at inopportune times, but I'd never seen it myself until a family member handed me their Galaxy Note 5. The malware certainly exists: it triggers in any application and shows a full-screen advertisement, probably in a WebView. The recent applications list shows no identifying marks, and trying to switch to it just returns to the home screen. Figuring this one out is going to require adb.
by Alex Chernyakhovsky (noreply@blogger.com) at December 22, 2015 01:16 AM
If you're setting up a service where people can register their own usernames to be used as a hostname (username.example.com), email address (username@example.com), or URL path (example.com/username) within your domain, there are some common names you should avoid letting the general public register.
Many Internet protocols make the assumption that a domain is manually managed by its owners, and in particular assume that a name like admin must have been registered or approved by the actual owners. Automatic registration breaks this assumption, and has been the source of some attacks. Microsoft Live has fallen victim to this multiple times: in 2008, a researcher signed up for sslcertificates@live.com and used it to get a login.live.com certificate, and as late as this March, the same problem happened to live.fi, the Finnish version of the service, when an IT professional tried registering the email account hostmaster@live.fi as his personal Live account, and then found he could receive a certificate for that domain.
This is a list of all the names I know that should be restricted from registration in automated systems. If you know of others, please let me know and I'll update this page.
tl;dr: Regardless of how you're currently using usernames, restrict them to lowercase letters, digits, and hyphens, starting with a letter and not ending with a hyphen (that is, /^[a-z]([a-z0-9-]*[a-z0-9])?$/ as an extended regex). Ban all the names in this file (last updated 2015-11-21). Get yourself listed as a public suffix: see below for directions and implications.
Most of these problems involve a computer on the domain doing an unqualified lookup: when a computer named a.example.com looks for b, it will usually find b.example.com. If you're running a simple hosting service, or similar, you may not need to block all of these, but these names are extremely unlikely to be used by legitimate users anyway. So you may as well block all of them to allow expanding in the future.
- localhost, localdomain, and broadcasthost: these are usually present in /etc/hosts, and applications or scripts might hard-code an assumption about them having their usual value (especially for localhost).
- www: Browsers will often prepend this if the domain itself does not resolve as a hostname.
- wpad: Web Proxy Auto-Discovery in several browsers; someone who owns this (unqualified) name can act as a proxy for all web traffic.
- isatap: IPv6 tunnel autodiscovery, primarily on Windows. Similarly to WPAD, someone who owns this (unqualified) name can act as a proxy for all IPv6-capable traffic. Windows Server has a built-in blacklist of domain names that defaults to WPAD and ISATAP.
- autoconfig: Thunderbird's spec for autoconfiguration. Thunderbird will query the website at autoconfig.example.com for settings when attempting to set up example.com email. Good way to harvest passwords.
- imap, pop, pop3, smtp, mail: for email clients that make guesses about what your email servers are. (This includes Thunderbird but also many others.)

Note that valid hostnames are restricted in syntax: they must only contain letters, digits, or hyphens, and cannot start or end with a hyphen. DNS is case-insensitive, so make sure there are no case collisions. An older standard prevents hostnames from starting with a digit, which is a straightforward way to prevent all-numeric usernames (which can cause problems with tools that accept either names or UIDs). Dots separate portions of a domain name and cause various problems (wildcard certificates only apply to one level, a.b.example.com can read and write cookies for b.example.com, etc.), so they're usually more trouble than they're worth. DNS records are much more liberal, but names that don't follow these rules will generally not resolve as hostnames: you can look them up with dig/host/etc., but you can't use them in applications. Checking hostname syntax also prevents you from worrying about names like _tcp or _udp, which are used in SRV records.
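Since all of these rules boil down to the hostname grammar above plus a blocklist, a small validation routine can enforce them before any name is handed out. This is a minimal sketch, not code from any particular service; the function name and the tiny example blocklist here are made up for illustration, and a real deployment would use the full list linked above.

#include <ctype.h>
#include <string.h>
#include <stdio.h>

/* Hypothetical checker for /^[a-z]([a-z0-9-]*[a-z0-9])?$/ : lowercase
 * letters, digits, and hyphens only, starting with a letter and not
 * ending with a hyphen, plus a (deliberately incomplete) blocklist. */
static const char *blocked[] = { "admin", "hostmaster", "postmaster", "webmaster", "wpad", NULL };

int username_ok(const char *name) {
    size_t len = strlen(name);
    if (len == 0 || !islower((unsigned char)name[0]))
        return 0;                       /* must start with a lowercase letter */
    if (name[len - 1] == '-')
        return 0;                       /* must not end with a hyphen */
    for (size_t i = 0; i < len; i++) {
        char c = name[i];
        if (!islower((unsigned char)c) && !isdigit((unsigned char)c) && c != '-')
            return 0;                   /* only [a-z0-9-] allowed */
    }
    for (int i = 0; blocked[i]; i++)
        if (strcmp(name, blocked[i]) == 0)
            return 0;                   /* reserved name */
    return 1;
}

int main(void) {
    /* prints "1 0 0": a normal name is fine, a reserved name and a
     * hyphen-leading name are rejected */
    printf("%d %d %d\n", username_ok("geofft"), username_ok("admin"), username_ok("-bad"));
    return 0;
}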
Most parts of the web platform consider two pages with different origins, that is, scheme (http / https), hostname, and port number, to be unrelated websites that cannot interact with each other by default. However, there are a few exceptions, most notably cookies. Web pages at www.example.com and login.example.com are allowed to set cookies with a scope of example.com, despite not sharing the same hostname / origin. The simple rule of allowing parent domains created the problem of supercookies: example.com could set a cookie scoped to .com, which would then be sent to all sites ending in .com. There are two big problems with this: the first is privacy (being tracked across websites), and the second is session-fixation attacks, where an attacker can overwrite your session cookie with their own, and have your actions (including logging in or sending private data) happen within the attacker's session.
The immediate fix was to ban top-level domains, but this still allowed setting cookies for publicly-registrable suffixes like .co.uk that weren't at the top level. So browser vendors created the public suffix list to track which suffixes are open for public registration. The public suffix list now includes not only "ICANN" entries, such as .com and .co.uk, but also "private" entries, such as .herokuapp.com and .github.io, since the same problems exist with allowing users to set cookies for all Heroku or GitHub Pages users.
So, if you are letting users register hostnames in your domain, you should get it listed as a public suffix, which requires just sending a pull request or an email. It takes some time for the update to reach browsers (the list is compiled into browsers, so it's only updated by a browser version update), so you should try to do this as far in advance as possible before launching.
Note that by making example.com a public suffix, nobody, not even code on example.com itself, can set a cookie for example.com. If you have a website of your own that needs cookies (analytics, registration, etc.), you'll need to run it at e.g. www.example.com, and make example.com just a redirect. Alternatively, you can use a completely separate domain for your own site vs. your users' sites, as with the Heroku and GitHub examples: their own websites are heroku.com and github.com.
The CA/Browser Forum Baseline Requirements, section 3.2.2.4 item 4, requires that if a CA is going to validate a domain by coming up with an administrative email address on its own, it may only use admin, administrator, webmaster, hostmaster, or postmaster. Reserve all of those names, regardless of whether they go somewhere useful.
All CAs are supposed to be compliant with that these days, but for safety's sake, also reserve root, info, ssladmin, ssladministrator, sslwebmaster, sysadmin, is, it, and mis (see this 2009 comment on Mozilla's bug tracker).
RFC 2142 defines the names info, marketing, sales, support, abuse, noc, security, postmaster, hostmaster, usenet, news, webmaster, www, uucp, and ftp. You won't need most of these to actually reach a useful mailbox, though you should reserve all of them.
You may want to reserve mailer-daemon, nobody (a default UNIX user account), noreply, no-reply, etc. for automated processes that send email.
Again, as these names are unlikely to be used by legitimate users, it's usually worth blocking them now and keeping your options open, even if you're not currently offering email service. You may add an email service in the future (Amazon launched Send to Kindle by email over a decade after introducing user accounts). As always, you can manually register these names to trusted or internal users.
For many websites with user-provided content, like Twitter, Facebook, or GitHub, user-chosen usernames become part of the URL at top level (https://twitter.com/geofft, https://github.com/geofft). If you're building a website like this, the easiest approach is to restrict these usernames as if they were hostnames. This has two advantages: the first is that it's easy to launch a hostname-based system later (e.g. GitHub Pages now supports geofft.github.io) if you know that all your usernames are valid hostnames.
The second is that there are several URL paths you need to reserve at top level, and all of them happen to contain dots and are therefore invalid hostnames. If you do permit dots, you need to block the following names:
- robots.txt, for the Robots Exclusion Protocol, used to tell well-behaved crawlers how to well-behave.
- favicon.ico, for the shortcut icon displayed in the tab bar and other places.
- crossdomain.xml, which allows the Flash plugin to make cross-origin requests. Java and Silverlight also look for and trust crossdomain.xml.
- clientaccesspolicy.xml, a Silverlight-specific version of crossdomain.xml.
- .well-known, specified in RFC 5785 as a place for these sorts of things so they don't keep cluttering the root level. Thunderbird autoconfiguration looks in here, as do ACME, the automatic certificate enrollment spec from Let's Encrypt; BrowserID / Mozilla Persona; and RFC 7711, a new standard for providing certificates for third-party non-HTTP services. So there are a number of security issues with an unauthorized user being able to create files under /.well-known/.

(These are URLs, not filenames. You should of course also disallow users from creating files named e.g. .htaccess if your web server respects those.)
All of these are invalid hostnames, so simply requiring usernames to be valid hostnames avoids having to check for these specific cases. If you're only allowing users to choose some portion of the URL, and inserting other text (e.g., example.com/user/geofft, example.edu/~geofft), then you don't have to worry about this, but again it may still be useful to keep your options open for other URL, hostname, or email schemes in the future.
Do not allow users to publish custom HTML, especially not custom scripts, at these sorts of URLs. https://example.com/user1, https://example.com/user2, and https://example.com/login all share the same origin, so by the same-origin policy, these web pages can freely interact with each other and mess with each other's content. A few JavaScript interfaces, including service workers, make it very easy to attack another site on the same origin. If you want users to be able to publish custom HTML and JS, use separate hostnames within a public suffix. https://user1.example.com and https://user2.example.com are separate origins, and if you have made example.com a public suffix as mentioned earlier, you can safely let them publish custom scripts, since the sites are no more able to interact with each other than two separate .com websites could.
This post was inspired by a GitHub issue for Sandstorm's sandcats.io dynamic DNS service; thanks to Asheesh Laroia for pointing me at that thread and reviewing a draft of this article.
by Geoffrey Thomas at November 26, 2015 12:00 AM
I ran across something strange while learning about Rust's stack overflow and segmentation fault handling.
First, some backstory: in the past, Rust (and Go) used segmented stacks, also known as split stacks. This is a scheme that allows you to start each thread with a small amount of stack space, and dynamically allocate a new, discontiguous chunk of stack space when you run out. Each function starts with a call to __morestack to see if it has enough stack space for that function's variables. If not, __morestack is supposed to allocate more space, quietly switch out the stack pointer, and call the function. When the function returns, __morestack frees the additional space and restores the original stack pointer. For a system like pre-1.0 Rust's tasks or Go's goroutines, allocating lots of tiny, growable stacks makes sense.
However, Rust's __morestack function currently only serves to trigger stack-overflow handling: the only thing that it does, besides aligning the stack properly, is call the rust_stack_exhausted function, which prints an error message and exits. Rust gives each thread plenty of stack space, so if it overflows, it probably means there's unbounded recursion and the thread should be terminated.
I thought this was a little odd. When you overflow the stack in C, or in any other language without __morestack, the next instruction tries to access unallocated memory. This causes a standard segmentation fault, which terminates the program. What's the need for Rust to catch this and terminate the program on its own?
Part of the answer is to provide a better error message. A straightforward stack overflow is not a memory safety violation in the same way an access to out-of-bounds memory is, so a segmentation fault is the wrong way to report a stack overflow. But the real answer turns out to be subtler than that. While accessing an invalid page of memory cannot cause data corruption, it's not guaranteed that the page is in fact invalid! With enough luck, you can overflow the stack far enough and reach into a completely unrelated memory region.
It turns out that this is possible in well-defined C, in a worrisomely straightforward way. The following program generates no warnings:
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

void clobber(uintptr_t target_address) {
    int dummy;
    uintptr_t dummy_address = (uintptr_t)&dummy;
    char array[dummy_address - target_address];
    array[50] = 'x';
    (void)array; // suppress warning about unused variable
}

int main(void)
{
    int i;
    int *x = malloc(20 * sizeof(int));
    for (i = 0; i < 20; i++)
        x[i] = 3;
    clobber((uintptr_t)x);
    for (i = 0; i < 20; i++)
        printf("%d ", x[i]);
    printf("\n");
    return 0;
}
and produces this output on my computer (an x86-64 machine running Debian GNU/Linux 8):
geofft@titan:/tmp$ gcc -Wall -Wextra -Wpedantic --std=c99 -o lol-stacks lol-stacks.c
geofft@titan:/tmp$ ./lol-stacks
3 3 3 3 7864323 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
What's going on? We create a variable on the stack called dummy, to figure out where our stack is. Then we figure out how far we are from the target address and create a giant array of exactly that size. Now the end of the stack lines up, more or less, with the target address. The process of "creating an array" doesn't involve any actual changes to memory, just changes to the interpretation of memory (and usually updating the stack pointer register), so while there's a wide swath of invalid memory between &dummy and target, none of it is accessed. Therefore, because the only access is to a valid memory location, there's no segmentation fault generated, and the change to memory goes through. When the program writes to array, it finds valid, writable memory at that location—namely, one of the elements in the target array x. (Remember that on x86 machines, the stack grows downwards, towards lower-valued addresses, so array[0] is the end of the stack. On other architectures, if the stack grows upwards, we'd need to write to the other end of array.)
As you might expect with memory corruption, this is exploitable. In 2010, Rafal Wojtczuk, a researcher at Invisible Things Lab, discovered that this sort of confusion could be used in a privilege-escalation exploit against the X server (PDF). A client can cause the X server to exhaust most of its memory, request a shared-memory region that gets mapped right past the stack, and then trigger a recursive function in the X server. As the associated blog post puts it, it "exploits a bug in... well, actually it doesn't exploit any concrete bug, which makes it so much more interesting."
How can we defend against this? The X server vulnerability (CVE-2010-2240) relied on the stack running right up against mapped pages, so the straightforward solution was to ensure that an unmapped guard page is always just past the end of the stack. But this doesn't help for more complicated cases that span more than a single page—like my example above, or even a less-contrived function with a 5000-byte string buffer.
It turns out that GCC has an option, -fstack-check, that inserts code to check for this problem. If I recompile my program with -fstack-check, I get a segmentation fault. It appears that GCC has inserted code to step the stack down one page at a time, performing a bitwise OR with zero at each point, which doesn't affect any value on the stack but forces a memory access. Here's the disassembly of that section of code, from GDB:
0x0000000000400621 <+139>: mov %rsp,%rcx
0x0000000000400624 <+142>: sub %rdx,%rcx
0x0000000000400627 <+145>: cmp %rcx,%rsp
0x000000000040062a <+148>: je 0x40063a <clobber+164>
0x000000000040062c <+150>: sub $0x1000,%rsp
=> 0x0000000000400633 <+157>: orq $0x0,(%rsp)
0x0000000000400638 <+162>: jmp 0x400627 <clobber+145>
The size of my variable-length array has already been computed in %rdx. GCC has inserted code to step through the stack one page ($0x1000 bytes) at a time, in a loop. At each step, it ORs the value at the stack pointer with zero, which doesn't change any data but forces a read and write of memory.
Why isn't this the default, as other safety mechanisms like -fstack-protector are? I don't know for sure. Part of the answer could be performance, since each somewhat large stack allocation will result in this code being run—with overhead linear in the size of the allocation, instead of just a single change to the stack pointer—but I'm not sure if anyone expects large stack allocations to be performant. Another answer could be lack of interest, since GCC doesn't even enable it by default when compiling Ada code, even though the Ada language requires that stack overflows and segmentation faults be distinct errors.
This technique is common on Windows. Microsoft's rationale for stack checking is less about preventing unexpected changes to memory and more about making sure the stack is actually in place. Windows places a single guard page past the end of the stack, which is reserved in the memory map, but causes a fault when accessed. The kernel takes this as an indication that the process needs more stack space, and will allocate more physical pages, provided that resource limits allow it and there's enough free RAM. Then it moves the guard page past the new end of the stack. So, if your program has an object that is even slightly more than the page size (e.g., a char buffer[5000]), it's important that the guard page get touched, instead of skipped over. (Linux also uses a similar mechanism, but as far as I can tell, the entire possible stack region is treated as guard pages, so the problem of skipping over a single guard page doesn't arise.)
Rust is looking at moving away from the __morestack approach to this approach, under the name of stack probes (the term "probe" is also used in GCC documentation). Apparently the right name for this technique is a bit of an open question. "Stack checking," as Microsoft and GCC call it, is a somewhat generic name, and I imagine many developers would interpret it as referring to the more typical direction of stack protection. Indeed, I was very surprised to see that the first C program above "worked", clobbering memory, even with -fstack-protector. It's common to see stack-smashing attacks, where an out-of-bounds pointer overwrites something valuable on the stack, and that's what -fstack-protector (StackGuard, which uses stack canaries) protects against. But it's not as common to see the inverse problem, where an out-of-bounds stack pointer overwrites something elsewhere in memory, which can be just as unsafe and just as much of a security risk.
Thanks to Nelson Elhage, Liz Denys, Alex Dehnert, and Asheesh Laroia for feedback on drafts of this post.
by Geoffrey Thomas at July 06, 2015 12:00 AM
UNIX (well, POSIX) signals are probably one of the worst parts of the UNIX API, and that’s a relatively high bar. A signal can be sent to your program at any point when it’s running, and you have to deal with it somehow. Traditionally, there are three ways of responding to signals (“dispositions”): you can let the default behavior run, which is often to kill the program; you can ignore it; or you can set up a function to be a signal handler.
If you register a signal handler, it’s called in the middle of whatever code you happen to be running. This sets up some very onerous restrictions on what a signal handler can do: it can’t assume that any locks are unlocked, any complex data structures are in a reliable state, etc. The restrictions are stronger than the restrictions on thread-safe code, since the signal handler interrupts and stops the original code from running. So, for instance, it can’t even wait on a lock, because the code that’s holding the lock is paused until the signal handler completes. This means that a lot of convenient functions, including the stdio functions, malloc, etc., are unusable from a signal handler, because they take locks internally. POSIX defines a set of “Async-Signal-Safe” functions that are guaranteed not to rely on locks or non-atomic data structures (because a signal handler could be called between two instructions that update a non-atomic data type). Linux’s signal(7) manual page has a list of these functions. But because of the constrained environment, one common approach is to defer the actual work until you’re no longer running in a signal handler. A standard solution is the “self-pipe trick”: open an unnamed pipe, write a byte to it in your signal handler and return, and read from it in the main loop. This makes signal handling resemble the handling of anything else that’s a file descriptor, like network sockets, which normalizes their handling a bit.
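Here is a minimal sketch of the self-pipe trick (not from any particular project; the signal choice and buffer sizes are arbitrary): the handler only calls write, which is async-signal-safe, and the main loop learns about the signal the same way it would learn about a readable socket.

#include <fcntl.h>
#include <poll.h>
#include <signal.h>
#include <stdio.h>
#include <string.h>
#include <unistd.h>

static int pipefds[2];

static void on_signal(int sig) {
    (void)sig;
    char byte = 0;
    /* write() is async-signal-safe; defer all real work to the main loop. */
    (void)write(pipefds[1], &byte, 1);
}

int main(void) {
    pipe(pipefds);
    /* Non-blocking write end, so a full pipe can never hang the handler. */
    fcntl(pipefds[1], F_SETFL, O_NONBLOCK);

    struct sigaction sa;
    memset(&sa, 0, sizeof sa);
    sa.sa_handler = on_signal;
    sigaction(SIGINT, &sa, NULL);

    struct pollfd pfd = { .fd = pipefds[0], .events = POLLIN };
    for (;;) {
        /* The signal may interrupt poll() with EINTR; we simply poll again,
         * and the byte already sitting in the pipe makes it return at once. */
        if (poll(&pfd, 1, -1) > 0 && (pfd.revents & POLLIN)) {
            char buf[64];
            (void)read(pipefds[0], buf, sizeof buf);
            printf("got SIGINT via the pipe\n");
        }
    }
}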
There’s another problem with signal handlers: in addition to interrupting user code, they also need to interrupt kernel code so that it gets delivered. So, if you’re in the middle of reading from a network socket, waiting for a timeout, or even closing a file descriptor, that system call gets aborted so that the signal handler can run. By default, once you exit the signal handler, your system call will appear to return early, with an error exit of errno == EINTR (“Interrupted system call”). The intention is that you can restart the call yourself if you want, but in practice, very few libraries do, which causes confusing errors for application developers. (Extra credit: explain why the linked solution is wrong.) To avoid these sorts of problems, BSD added a mechanism to mark signal handlers as restartable, which was standardized as the SA_RESTART flag to the sigaction system call. But that does not reliably cause system calls to continue instead of returning EINTR: as documented later in signal(7), Linux will return EINTR in some cases even when a system call is interrupted by an SA_RESTART handler. (The rationale for many of these makes sense—e.g., timeouts where restarting the call would start a timer from zero—but understanding the problem doesn’t make it any less of a problem.)
So, in 2007, Linux gained an API called signalfd, which lets you create a file descriptor that notifies on signals. The idea is that you can avoid the complexity of the self-pipe trick, as well as any problems with EINTR, by just asking the kernel to send you signals via a file descriptor in the first place. You don’t need to register a signal handler, so everything works perfectly… except that signalfd doesn’t actually change how signal dispositions work. If you only create a signalfd, signals still get delivered via the signal-handling mechanism, so in order to actually get the signal delivered over the file descriptor, you need to suspend normal processing of the signal. There’s another part of the signal API that lets you set a mask to block the signal: the intention is that you can remove the mask later, and then pending signals will get delivered. But this also means that pending signals are readable from the signalfd, and they’re removed from the queue once read.
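For comparison, here is a bare-bones sketch of that flow (Linux-only, error handling omitted, SIGCHLD chosen arbitrarily): the signal has to be blocked first, and each read dequeues one pending—possibly coalesced—signal.

#include <sys/signalfd.h>
#include <signal.h>
#include <stdio.h>
#include <unistd.h>

int main(void) {
    sigset_t mask;
    sigemptyset(&mask);
    sigaddset(&mask, SIGCHLD);
    /* Without this, the signal is still delivered the normal way and
     * never shows up on the file descriptor. */
    sigprocmask(SIG_BLOCK, &mask, NULL);

    int fd = signalfd(-1, &mask, 0);

    if (fork() == 0)
        _exit(0);               /* generate one SIGCHLD to read back */

    struct signalfd_siginfo si;
    /* Each successful read removes one pending signal from the queue. */
    if (read(fd, &si, sizeof si) == sizeof si)
        printf("signal %u from pid %u\n", si.ssi_signo, si.ssi_pid);

    close(fd);
    return 0;
}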
This would work fine if signal masks applied only to your current process, but they are also inherited by child processes. So if you’re masking a signal like SIGINT in order to receive it via signalfd, and you run a subprocess, then that child process will start up with SIGINT masked. And while programs could reset their own signal masks when they start up, in practice, no software does. So you have to be very careful to reset any masked signals before starting a child process, and unfortunately you need to do this yourself: most ways of starting child processes, including the standard libc system(3) function, do not reset handlers or masks.
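A hedged sketch of the kind of cleanup this forces on you (the helper name is made up, and "mask" stands for whatever set of signals the parent blocked for its signalfd): the child has to undo the parent’s mask itself between fork and exec, because system(3) and friends won’t do it for you.

#include <signal.h>
#include <unistd.h>

/* Hypothetical helper: spawn a child with the parent's blocked signals
 * unblocked again, so the new program starts with a clean mask. */
pid_t spawn_with_default_mask(char *const argv[], const sigset_t *mask) {
    pid_t pid = fork();
    if (pid == 0) {
        sigprocmask(SIG_UNBLOCK, mask, NULL);   /* undo the parent's blocking */
        execvp(argv[0], argv);
        _exit(127);                             /* exec failed */
    }
    return pid;
}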
There’s another problem with masking signals, which is that standard UNIX signals are permitted to coalesce when they queue. Namely, if you have a mask that’s blocking a signal, and that signal gets delivered twice before you process the first one from the queue, it only gets reported to you once. For something like SIGINT, this might be okay: probably you don’t need to handle multiple Ctrl-C keystrokes differently from a single one. For something like SIGCHLD, which notifies you that a child process has terminated, this is quite a bit more unfortunate. It now means that you know that at least one child process has terminated. And while signals can have associated info (via the SA_SIGINFO flag to sigaction, or by default for signalfd), and SIGCHLD’s siginfo tells you which process has terminated, if it gets coalesced, you only get one set of info. That is, you know that the specified child process has terminated, but also one or more other unspecified child processes could have terminated and the information about which one has been lost. Not only does this apply to unmasking signals and letting them get delivered to a normal signal handler, it also applies to reading signals from a signalfd, since signalfd works by dequeuing signals.
It turns out that there is, at least in theory, a better way to handle this. Just go back to the old self-pipe trick, but install the handler with the SA_NODEFER flag, which allows signal handlers to interrupt themselves. This way you reliably get one handler called per signal. [EDIT: This is wrong, see below.] This brings back the EINTR problem, but there’s a simple solution to that: simply dedicate a separate thread to signal handling, and make sure all signals get routed there. Then no system calls on any other thread will get interrupted. The only way to ensure this is to mask signals on every other thread, but that’s no worse than the status quo with signalfd, since we already had to mask signals in order to get them delivered to signalfd alone. As a bonus, this is portable to other operating systems that don’t have signalfd, so this is the solution I’m planning on implementing for the MIO cross-platform event-handling library for Rust.
Can signalfd be fixed? I think it can, if we explicitly tie signalfd into the signal-disposition mechanism. If signalfd could claim responsibility for signal delivery, instead of requiring that signals be masked or ignored in addition to using signalfd, this would solve both problems. For inheritance, just set the close-on-exec flag on the signalfd, and signals will go back to the default behavior in child processes, once the fd no longer exists. For multiple delivery, because signalfd no longer interacts with signal queueing, it can just be specified to send one message per signal. Alternatively, a mechanism to mask or ignore a signal that applies to the current process only would also be an improvement. These proposals would be a change to POSIX signal semantics, but fundamentally so is signalfd, and the only way to get clean handling of signals is to avoid the POSIX API, not work within it.
EDIT: Greg Hudson pointed out to me that even traditional handlers can coalesce signals. If a signal is sent twice before the signal-handling thread is scheduled, it still gets coalesced, even if it's destined for a signal handler with SA_NODEFER. So while a signal-handling thread can get you separate siginfo most of the time, it still can’t do so reliably. I’ve written an example of this behavior that uses vfork to prevent the process from being scheduled: it starts multiple children, but only prints one notification. So my claim that there’s a better alternative to signalfd is wrong, but it’s still true that signalfd doesn’t solve any of the other troubles with signal handling: it just saves you a separate thread.
Thanks to Nelson Elhage for pointing me at the LWN article, and Alex Chernyakhovsky, David Benjamin, and Yehuda Katz for discussions about all the unfortunate parts of signal handling.
by Geoffrey Thomas at May 17, 2015 12:00 AM
The Name Service Switch (NSS) is the feature of many C libraries, including the standard Linux one (GNU libc) and the standard Solaris one, that allows name lookup routines to be implemented via plugins. "Name lookup" here refers to things like usernames, host names, etc. When you run ls -l, and the ls command maps numeric user IDs to usernames, the names are provided by any module listed in the Name Service Switch's configuration. This could be the files module that looks at local files like /etc/passwd, the ldap module that talks to a corporate LDAP server, etc.
One of my half-forgotten side projects involves writing a new NSS module. NSS is configured via the file /etc/nsswitch.conf, which is systemwide configuration. If I want to test my NSS module, I could install it globally and reconfigure my system. But I started wondering if there was a way to load an NSS module just for a single process under my control. Since the process is running in my user account, I should be able to reconfigure it without affecting the rest of the system.
Turns out that it's possible in a somewhat hackish way. There's an internal glibc function called __nss_configure_lookup that overrides NSS settings. If you're writing your own test program, you can just call, e.g., __nss_configure_lookup("passwd", "files") to force all user lookups to go through libnss_files. If you're using an existing program, you can shoehorn this in by use of an LD_PRELOAD:
#include <nss.h>
#include <stdlib.h>

static void __attribute__((constructor))
nsstest_ctor(void)
{
    const char *db = getenv("NSSTEST_DB"), *config = getenv("NSSTEST_CONFIG");
    if (db && config)
        __nss_configure_lookup(db, config);
}
Compile with gcc -fPIC -shared -o nsstest.so nsstest.c. Then you can do things like this:
howe-and-ser-moving:/tmp geofft$ ls -ld ~
drwxr-xr-x 355 geofft root 34816 Apr 17 21:25 /afs/athena.mit.edu/user/g/e/geofft
howe-and-ser-moving:/tmp geofft$ LD_PRELOAD=./nsstest.so NSSTEST_DB=passwd NSSTEST_CONFIG=files ls -ld ~
drwxr-xr-x 355 40490 root 34816 Apr 17 21:25 /afs/athena.mit.edu/user/g/e/geofft
Since my account isn't in the files database on this machine (it's in hesiod), my user ID can no longer be looked up if I restrict passwd lookups to the files database. A more straightforward way of testing is using the getent command that ships with glibc, which lets you ask for a specific entry in a specific NSS database. For instance, both files and hesiod have entries for the root user:
howe-and-ser-moving:/tmp geofft$ LD_PRELOAD=./nsstest.so NSSTEST_DB=passwd NSSTEST_CONFIG=files getent passwd root
root:x:0:0:root:/root:/bin/bash
howe-and-ser-moving:/tmp geofft$ LD_PRELOAD=./nsstest.so NSSTEST_DB=passwd NSSTEST_CONFIG=hesiod getent passwd root
root:*:0:101:Wizard A Root,,,:/mit/root:/bin/csh
If you're writing your own NSS library, you'll also need to set LD_LIBRARY_PATH to point to the directory where it lives, since NSS configuration just takes names, not full paths.
by Geoffrey Thomas at April 18, 2015 12:00 AM
I just got back from PyCon, where a few debian-python team members met to discuss some goals for Python packaging in Debian over the next release cycle. One item of interest is moving away from installing Python 2 by default (expecting it to be desupported in 2020), which raises questions about what /usr/bin/python should mean. At the moment, it very strongly means Python 2 specifically, so there's a question about what it should be on a system with only Python 3 installed—should it become Python 3? Should it be nonexistent?
PEP 0394 recommends that Python 2 should be installed as python2 and Python 3 as python3, and that python should mean Python 2 for now. I think it's important that python continue to mean Python 2, for the reasons described in the "Migration Notes" section of that PEP. In particular, I think it's very important that you should be able to install Python 2 (for the near future), even if your system did not ship with Python 2, and that the system version of Python 3 should not prevent you from using python to run Python 2 applications.
However, not shipping /usr/bin/python by default also has its downsides. I made a suggestion on the debian-python list for handling this: this blog post is a more detailed version of that proposal.
The basic motivation for this proposal is that third-party script authors love Python in part because it's just about universally available. #!/usr/bin/env python, today, is sort of like #!/bin/sh: you don't necessarily get nice things (Python 3-only features and bash-only features, respectively), but it works. If the python command stops existing, this is inconvenient for end users, who will need to manually edit scripts or install a version of Python themselves, and authors, who may just decide to use a worse language like #!/bin/sh. It's also bad for the Python language community, since it removes an incentive to use Python.
So it would be nice to keep the python command working usefully on both Python 2-only systems and Python 3-only systems. Fortunately, it's pretty doable to port code to work on both Python 2 and 3 with the same source. Especially for third-party scripts with the goal of working out-of-the-box in as many places as possible, writing code to these restrictions isn't particularly onerous, since they needed to stay Python 2-compatible anyway. I've been writing a bunch of Python code recently as polyglot Python 2/3, and the biggest problem has been limiting myself to features in Python 2.7, not getting things to work in both versions of the language.
So here's the proposal: we install a wrapper binary as python, that defaults to launching Python 2 (or reporting an error if Python 2 is not installed). However, Python 2/3-compatible scripts can include a specific marker indicating they can be run on Python 3, which makes these scripts able to run on both Python 2-only systems and Python 3-only systems.
This marker is based on the existing coding: syntax from PEP 0263: it's a "magic comment" on the first or second line of the form pyversions=2.7+,3.3+, indicating which Python major and minor versions are supported. A script compatible with both Python 2 and 3 should include one of these magic comments, and also include a shebang line launching just python. Here's an example based on one from PEP 0263:
#!/usr/bin/env python
# -*- coding=utf-8 -*- pyversions=2.6+,3.3+
from __future__ import print_function
print(u"Hello world!")
On new systems, the python command itself is a wrapper binary that knows what versions of the Python interpreter are installed. If it detects a pyversions comment, it will exec the newest Python interpreter compatible with the comment. For this example, if Python 3.3 is installed, it will launch that: if only Python 3.2 and 2.7 are installed, it will use Python 2.7, since Python 3.2 does not support the u"" literal syntax. Otherwise, it will assume that code is Python 2-only, and exec the newest Python 2 version installed. In either case, if a compatible interpreter cannot be found, it will print an error and exit.
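To make the dispatch concrete, here is a rough sketch of how such a wrapper binary might behave for the script case only. This is not the proof-of-concept linked below, just an illustration under simplifying assumptions: it treats any marker that lists a 3.x version as "may run on Python 3", hard-codes the python2/python3 names instead of discovering the newest compatible minor version, and ignores the interactive and -c modes discussed later.

#include <stdio.h>
#include <string.h>
#include <unistd.h>

int main(int argc, char **argv) {
    int allow_python3 = 0;

    if (argc > 1) {
        FILE *f = fopen(argv[1], "r");
        if (f) {
            char line[1024];
            /* The magic comment must be on the first or second line. */
            for (int i = 0; i < 2 && fgets(line, sizeof line, f); i++) {
                char *m = strstr(line, "pyversions=");
                /* Simplification: any marker mentioning a 3.x version is
                 * taken to permit the installed Python 3. */
                if (m && strstr(m, "3."))
                    allow_python3 = 1;
            }
            fclose(f);
        }
    }

    /* No marker means legacy code: assume Python 2, per the proposal. */
    argv[0] = allow_python3 ? "python3" : "python2";
    execvp(argv[0], argv);
    perror("exec");
    return 127;
}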
On legacy systems, the python command refers to Python 2. So Python 2/3-compatible scripts will still be able to find an interpreter they are compatible with.
Legacy scripts whose shebang states just python will, on both new and legacy systems, only ever run on Python 2, or not at all. This preserves the existing API, and avoids the poor user experience of running Python 2 scripts with the Python 3 interpreter, which PEP 0394 warns against.
However, the python wrapper supports running Python 2/3-compatible scripts with the Python 3 interpreter, which is useful for Python 3-only systems. A future version of Debian can choose to ship the Python 3 interpreter only, and remain compatible with third-party scripts that include the pyversions magic comment. Meanwhile, these third-party scripts can state #!/usr/bin/env python in their shebang, and remain compatible with legacy Python 2-only systems, including those (like Mac OS X) that do not include the python2 symbolic link.
The python wrapper can run in three possible modes:
- Running a script, as in python script.py, applying the pyversions logic described above.
- Interactive use, running just python, without the pedagogical speed bump of having to explain versions of the language. Running the latest major version of Python is safe, since the interpreter will print out its own version at startup.
- Command-line snippets with python -c are considered scripted use, since in this context, the python command is also serving as an API, and existing users of the API expect Python 2 just as much as scripts do. (Imagine, for instance, a shell script with KEY=$(python -c "import local_settings; print SECRET_KEY"), which will fail if python means Python 3.) In this mode, the wrapper will instead look for an environment variable named PYVERSIONS. If this is set, it will be parsed as the value of the pyversions magic comment. So shell scripts that include Python 2/3-compatible python -c commands can e.g. export PYVERSIONS=2.7+,3.3+ at the top of the script, and work on systems with either just Python 2 or just Python 3.

The end result here is that third-party script authors can continue writing polyglot Python 2/3 code for the indefinite future, and be compatible with both existing Python 2-only distributions and newer Python 3-only distributions. Meanwhile, distributions can move towards installing Python 3 only by default and avoid installing Python 2 as far as possible, without closing off the ability to install Python 2 when needed, or breaking legacy Python 2-only code.
This is a good transition story for distributions like Debian, Ubuntu, and Fedora that are currently trying to move to Python 3 only, but don't want to break existing code. It's also a good transition story for distributions like Arch that have already started moving /usr/bin/python to Python 3: interactive and scripted use of the python command can continue to launch Python 3, but compatibility with third-party Python 2 scripts is regained. And most importantly, it's good for third-party script authors, who want to write one script that works on Debian, Ubuntu, Fedora, Arch, and Mac OS X alike, and for their end users.
I'm interested in initial feedback on this idea: if it generally seems like a good plan, I'll firm up the details and write it up as a formal PEP. There are a few details to work out, like how this interacts with the py.exe Windows launcher described in PEP 0397 (if at all), but my focus is making /usr/bin/python useful and backwards-compatible on UNIX-like platforms.
Edit 3 May 2015: I've uploaded a proof-of-concept implementation of this launcher to GitHub to see what the overhead and complexity of this approach looks like.
by Geoffrey Thomas at April 17, 2015 12:00 AM
Powered by Planet Venus!
Last updated: April 11, 2026 08:07 PM