SAM 论文翻译+学习笔记
模型部分附上了代码细节,SAM模型架构在文末。 论文地址:https://arxiv.org/pdf/2304.02643.pdf
Abstract
We introduce the Segment Anything (SA) project: a new task, model, and dataset for image segmentation.
SA项目:一个图片分割场景的新任务、新模型、新数据集。
Using our efficient model in a data collection loop, we built the largest segmentation dataset to date (by far), with over 1 billion masks on 11M licensed and privacy respecting images.
构建了一个至今为止最大的图片分割数据集,1100万图片上,标了10亿个masks。
The model is designed and trained to be promptable, so it can transfer zero-shot to new image distributions and tasks.
模型是可提示的,所以可以zero-shot迁移应用到新的数据集或下游任务上。
We evaluate its capabilities on numerous tasks and find that its zero-shot performance is impressive – often competitive with or even superior to prior fully supervised results.
我们在很多任务上评估了zero-shot表现,它通常能具有竞争力或优于全监督结果,这很牛逼。
We are releasing the Segment Anything Model (SAM) and corresponding dataset (SA-1B) of 1B masks and 11M images at https://segment-anything.com to foster research into foundation models for computer vision.
我们release了模型SAM和SA-1B数据集。
1.Introduction
Large language models pre-trained on web-scale datasets are revolutionizing NLP with strong zero-shot and few-shot generalization [10].
在网络规模数据集上预训练的大型语言模型正在通过强大的零样本和少样本泛化彻底改变NLP [10]。
These “foundation models” [8] can generalize to tasks and data distributions beyond those seen during training.
这些“基础模型”[8] 可以泛化应用到训练时没见过的任务和数据分布上。
This capability is often implemented with prompt engineering in which hand-crafted text is used to prompt the language model to generate a valid textual response for the task at hand.
这种能力通常通过提示工程来实现,会手写一段文本,来提示语言模型为任务生成文本反馈。
When scaled and trained with abundant text corpora from the web, these models’ zero and few-shot performance compares surprisingly well to (even matching in some cases) fine-tuned models [10, 21].
当使用来自网络的丰富文本语料库进行扩展和训练时,这些模型的零样本和少样本性能出奇地好,甚至有些场景下能与微调模型相当。
Empirical trends show this behavior improving with model scale, dataset size, and total training compute [56, 10, 21, 51].
经验趋势表明,随着模型规模、数据集大小和总训练计算的增加,模型的零样本和少样本性能会越来越好[56,10,21,51]。
Foundation models have also been explored in computer vision, albeit to a lesser extent.
计算机视觉领域也对基础模型进行了探索,尽管程度较小。
Perhaps the most prominent illustration aligns paired text and images from the web.
可能最出名的例子就是对齐了来自网络的配对文本和图像。
For example, CLIP [82] and ALIGN [55] use contrastive learning to train text and image encoders that align the two modalities.
例如,CLIP [82] 和 ALIGN [55] 使用对比学习来训练对齐两种模式的文本和图像编码器。
Once trained, engineered text prompts enable zero-shot generalization to novel visual concepts and data distributions.
经过训练后,设计的文本提示可以零样本概括新的视觉概念和数据分布。
Such encoders also compose effectively with other modules to enable downstream tasks, such as image generation (e.g., DALL·E [83]).
这种编码器还可以与其他模块有效组合,以实现下游任务,例如图像生成(例如,DALL·E [83])。
While much progress has been made on vision and language encoders, computer vision includes a wide range of problems beyond this scope, and for many of these, abundant training data does not exist.
尽管在视觉和语言编码器方面取得了很大进展,但计算机视觉包含了超出此范围的广泛问题,并且对于其中许多问题,并不存在丰富的训练数据。
In this work, our goal is to build a foundation model for image segmentation.
在这项工作中,我们的目标是建立图像分割的基础模型。
That is, we seek to develop a promptable model and pre-train it on a broad dataset using a task that enables powerful generalization.
也就是说,我们要开发一个可提示的模型,在一个广泛的数据集上进行预训练,使用一个能够广泛泛化的task
With this model, we aim to solve a range of downstream segmentation problems on new data distributions using prompt engineering.
通过该模型,我们的目标是使用提示工程解决新数据分布上的一系列下游分割问题。
The success of this plan hinges on three components: task, model, and data. To develop them, we address the following questions about image segmentation:
决定计划成功的关键有三点:任务、模型、数据。我们解决了关于图像分割的如下三个问题:
- What task will enable zero-shot generalization? 什么任务能支撑零样本泛化?
- What is the corresponding model architecture? 对应的模型架构长什么样?
- What data can power this task and model? 实现这样的task和model,需要什么数据?
These questions are entangled and require a comprehensive solution.
这些问题错综复杂,需要综合解决。
We start by defining a promptable segmentation task that is general enough to provide a powerful pre-training objective and to enable a wide range of downstream applications.
我们首先定义一个可提示的分割任务,该任务足够通用,有一个强大的预训练目标,并可以支持广泛的下游应用。
This task requires a model that supports flexible prompting and can output segmentation masks in real-time when prompted to allow for interactive use.
该任务需要一个支持灵活提示的模型,并且可以在提示下实时生成分段掩码,以支撑交互使用场景。
To train our model, we need a diverse, large-scale source of data.
为了训练我们的模型,我们需要多样化的、大规模的数据源。
Unfortunately, there is no web-scale data source for segmentation;
不幸的是,没有用于分割的网络规模数据源;
to address this, we build a “data engine”, i.e., we iterate between using our efficient model to assist in data collection and using the newly collected data to improve the model.
为了解决这个问题,我们构建了一个“数据引擎”,即我们先使用模型协助数据收集,然后使用新收集的数据改进模型,形成循环。
We introduce each interconnected component next, followed by the dataset we created and the experiments that demonstrate the effectiveness of our approach.
接下来我们介绍每个互连的组件,然后是我们创建的数据集以及证明我们方法有效性的实验。
Task任务 (§2).
In NLP and more recently computer vision, foundation models are a promising development that can perform zero-shot and few-shot learning for new datasets and tasks often by using “prompting” techniques.
在 NLP和最近的计算机视觉中,基础模型是一项有前途的发展,通常可以通过使用“提示”技术对新数据集和任务执行零样本和少样本学习。
Inspired by this line of work, we propose the promptable segmentation task, where the goal is to return a valid segmentation mask given any segmentation prompt (see Fig. 1a).
受这一工作的启发,我们提出了提示分割任务,其目标是在给定任何分割提示的情况下返回有效的分割掩码(见图1a)。
A prompt simply specifies what to segment in an image, e.g., a prompt can include spatial or text information identifying an object.
提示只是指定在图像中分割什么,例如,提示可以是能识别出对象的空间描述或文本描述信息。
The requirement of a valid output mask means that even when a prompt is ambiguous and could refer to multiple objects (for example, a point on a shirt may indicate either the shirt or the person wearing it), the output should be a reasonable mask for at least one of those objects.
有效输出掩码的要求意味着即使提示不明确并且可能引用多个对象 (例如,衬衫上的点可能表示衬衫或穿着它的人), 输出应该是至少其中一个对象的合理掩码。
We use the promptable segmentation task as both a pre-training objective and to solve general downstream segmentation tasks via prompt engineering.
我们使用提示分割任务作为预训练目标,并通过提示工程解决一般下游分割任务。
Model模型 (§3).
The promptable segmentation task and the goal of real-world use impose constraints on the model architecture. In particular, the model must support flexible prompts, needs to compute masks in amortized real-time to allow interactive use,and must be ambiguity-aware.
可提示的分割这一任务设定、现实世界中能实际使用这一目标,都给模型架构带来了约束。 具体来说,模型必须支持灵活的提示,需要近实时地计算掩码以允许交互使用,还需要能够处理歧义。
Surprisingly, we find that a simple design satisfies all three constraints:
惊喜的是,我们发现一个简单的设计能满足所有这三个约束:
a powerful image encoder computes an image embedding,
一个强大的图像编码器,计算图像的embedding,
a prompt encoder embeds prompts,
一个提示编码器,生成提示的embedding,
and then the two information sources are combined in a lightweight mask decoder that predicts segmentation masks.
然后这两个信息作为一个轻量级掩码解码器的输入,它会预测分割掩码。
We refer to this model as the Segment Anything Model, or SAM (see Fig. 1b).
我们将该模型称为SAM(见图 1b)。
By separating SAM into an image encoder and a fast prompt encoder/mask decoder,the same image embedding can be reused (and its cost amortized) with different prompts.
通过将 SAM 分为一个图像编码器和快速的提示编码器/掩码解码器, 相同的图像embedding可以在不同的提示下重复使用,这可以节省图像embedding计算成本。
Given an image embedding, the prompt encoder and mask decoder predict a mask from a prompt in ∼50ms in a web browser.
给定一个图像嵌入,提示编码器和掩码解码器可以在浏览器中耗时~50毫秒根据提示生成掩码。
We focus on point, box, and mask prompts, and also present initial results with free-form text prompts.
我们专注于点、框和蒙版提示,也初步探索了使用格式不限的文本提示。
To make SAM ambiguity-aware, we design it to predict multiple masks for a single prompt allowing SAM to naturally handle ambiguity, such as the shirt vs. person example.
为了使 SAM 能够感知歧义,我们设计对单个prompt输入预测输出多个掩码,从而使 SAM 能够自然地处理歧义,例如衬衫与人的示例。
Data engine 数据引擎(§4).
To achieve strong generalization to new data distributions, we found it necessary to train SAM on a large and diverse set of masks, beyond any segmentation dataset that already exists.
为了实现对新数据分布的强泛化,我们发现有必要在大量且多样化的掩码上训练 SAM,任何现有的分割数据集都无法满足需要。
While a typical approach for foundation models is to obtain data online [82], masks are not naturally abundant and thus we need an alternative strategy.
虽然基础模型的典型方法是在线获取数据[82]【网络上图片和下面的标题就是天然的text-image训练数据】,但掩码数据天然就不丰富,因此我们需要一种替代策略。
Our solution is to build a “data engine”, i.e., we co-develop our model with model-in-the-loop dataset annotation (see Fig. 1c).
我们的解决方案是构建一个“数据引擎”,我们搞了一个model-in-the-loop的数据标注,然后产生的数据再用于优化模型。(见图1c)。
Our data engine has three stages: assisted-manual, semi-automatic, and fully automatic.
我们的数据引擎分为三个阶段:辅助手动、半自动、全自动。
In the first stage, SAM assists annotators in annotating masks, similar to a classic interactive segmentation setup.
在第一阶段,SAM 协助标注员标注掩模,类似于经典的交互式分割设置。
In the second stage, SAM can automatically generate masks for a subset of objects by prompting it with likely object locations and annotators focus on annotating the remaining objects, helping increase mask diversity.
在第二阶段,通过prompt可能的对象位置,SAM可以自动为待识别对象的一部分子集生成掩码,而标注员则专注于标注其余对象,从而帮助增加掩码的多样性。
In the final stage, we prompt SAM with a regular grid of foreground points, yielding on average ∼100 high-quality masks per image.
在最后阶段,我们使用前景点的规则网格提示SAM,平均每张图像能产生出约100个高质量掩码。
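For intuition, here is a minimal sketch of how such a regular grid of foreground point prompts could be built; the 32×32 grid size and the downstream filtering are assumptions for illustration, not the exact released pipeline.
import numpy as np

# Build an (n*n, 2) array of (x, y) foreground points in [0, 1]^2, laid out as the
# centers of a regular n-by-n grid. Each point is fed to SAM as an independent prompt;
# low-quality and duplicate masks are filtered afterwards (e.g., by predicted IoU and
# NMS) to end up with roughly 100 masks per image.
def build_point_grid(n_per_side: int = 32) -> np.ndarray:
    offset = 1.0 / (2 * n_per_side)
    coords_1d = np.linspace(offset, 1.0 - offset, n_per_side)
    xs, ys = np.meshgrid(coords_1d, coords_1d)
    return np.stack([xs.ravel(), ys.ravel()], axis=-1)

grid = build_point_grid(32)   # (1024, 2) candidate point prompts per image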
Dataset数据集 (§5).
Our final dataset, SA-1B, includes more than 1B masks from 11M licensed and privacy-preserving images (see Fig. 2).
我们的最终数据集 SA-1B 中有 11M 有许可的、保护了隐私图像,以及在这些图像上超过 1B 个掩模(见图 2)。
SA-1B, collected fully automatically using the final stage of our data engine, has 400× more masks than any existing segmentation dataset [66, 44, 117, 60], and as we verify extensively, the masks are of high quality and diversity.
SA-1B 使用我们数据引擎的最后阶段完全自动收集,其掩码比任何现有分割数据集 [66,44,117,60] 多 400 倍,并且正如我们广泛验证的那样,掩码具有高质量和多样性 。
Beyond its use in training SAM to be robust and general, we hope SA-1B becomes a valuable resource for research aiming to build new foundation models.
除了用于训练 SAM 使其变得稳健和通用之外,我们希望 SA-1B 成为构建新基础模型研究的宝贵资源。
Responsible AI (§6).
We study and report on potential fairness concerns and biases when using SA-1B and SAM.
我们研究并报告使用 SA-1B 和 SAM 时潜在的公平问题和偏见。
Images in SA-1B span a geographically and economically diverse set of countries and we found that SAM performs similarly across different groups of people.
SA-1B 中的图像跨越了地理和经济上不同的国家,我们发现 SAM 在不同人群中的表现相似。
Together, we hope this will make our work more equitable for real-world use cases.
我们共同希望这将使我们的工作在现实世界的用例中更加公平。
We provide model and dataset cards in the appendix.
我们在附录中提供模型和数据集卡。
Experiments (§7).
We extensively evaluate SAM.
我们广泛评估 SAM。
First, using a diverse new suite of 23 segmentation datasets, we find that SAM produces high-quality masks from a single foreground point, often only slightly below that of the manually annotated ground truth.
首先,使用了 23 个新的多样的分割数据集,我们发现 SAM 可以从单个前景点生成高质量掩模,通常仅略低于手动标注的ground truth。
Second, we find consistently strong quantitative and qualitative results on a variety of downstream tasks under a zero-shot transfer protocol using prompt engineering, including edge detection, object proposal generation, instance segmentation, and a preliminary exploration of text-to-mask prediction.
其次,我们使用提示工程零样本应用在了各种下游任务上,发现了一致的定量和定性结果,包括边缘检测、对象建议生成、实例分割以及文本到掩模预测的初步探索。
These results suggest that SAM can be used out-of-the-box with prompt engineering to solve a variety of tasks involving object and image distributions beyond SAM’s training data.
这些结果表明,SAM 可以通过提示工程开箱即用,解决SAM训练数据之外的对象和图像分布上的各种任务。
Nevertheless, room for improvement remains, as we discuss in §8.
尽管如此,正如我们在第 8 节中讨论的那样,改进的空间仍然存在。
Release.
We are releasing the SA-1B dataset for research purposes and making SAM available under a permissive open license (Apache 2.0) at https://segment-anything.com. We also showcase SAM’s capabilities with an online demo.
2.Segment Anything Task
We take inspiration from NLP, where the next token prediction task is used for foundation model pre-training and to solve diverse downstream tasks via prompt engineering [10].
我们从 NLP 中获得灵感,next token预测任务用于基础模型预训练,并通过提示工程解决各种下游任务[10]。
To build a foundation model for segmentation, we aim to define a task with analogous capabilities.
为了构建图片分割的基础模型,我们的目标是定义具有类似功能的任务。
Task.
We start by translating the idea of a prompt from NLP to segmentation, where a prompt can be a set of foreground/background points, a rough box or mask, free-form text, or, in general, any information indicating what to segment in an image.
我们首先将提示的概念从 NLP 领域迁移到分割领域,提示可以是一组前景/背景点、一个简单的框或蒙版、格式不限的文本,或者一般来说,任何可以指示图像中要分割的内容的信息。
The promptable segmentation task, then, is to return a valid segmentation mask given any prompt.
那么,可提示的分割任务是在给定任何提示的情况下返回有效的分割掩码。
The requirement of a “valid” mask simply means that even when a prompt is ambiguous and could refer to multiple objects (e.g., recall the shirt vs. person example, and see Fig. 3), the output should be a reasonable mask for at least one of those objects.
“有效”掩码的要求意味着,即使提示不明确并且可能指代多个对象(例如,衬衫与人的示例,请参见图 3),输出也应该是其中至少一个对象的合理掩码。
This requirement is similar to expecting a language model to output a coherent response to an ambiguous prompt.
这个要求类似于期望语言模型对一个含糊的提示输出一个连贯的回应。
We choose this task because it leads to a natural pre-training algorithm and a general method for zero-shot transfer to downstream segmentation tasks via prompting.
我们选择这个任务,因为它引出了一个自然的预训练算法,以及一种通过提示实现对下游分割任务进行零样本迁移的通用方法。
Pre-training.
The promptable segmentation task suggests a natural pre-training algorithm that simulates a sequence of prompts (e.g., points, boxes, masks) for each training sample and compares the model’s mask predictions against the ground truth.
可提示的分割任务采用了一种自然的预训练算法,该算法模拟每个训练样本的一系列提示(例如,点、框、掩模),并将模型的预测输出的掩码与正确标签进行比较。
【怎么判断?在ground truth的标注数据里,应该根据空间计算,能找到和points/boxes/masks重合的object list有哪些,然后模型预测的输出,看是不是在这个object list里?? 但模型要输出的是用精准的边缘线圈出来的object,这是一堆pixel集合,那是和ground truth里的每个可能object计算MSE?找loss最小的那个,算作匹配了?】
We adapt this method from interactive segmentation [109, 70], although unlike interactive segmentation whose aim is to eventually predict a valid mask after enough user input, our aim is to always predict a valid mask for any prompt even when the prompt is ambiguous.
这种方法是对交互式分割[109, 70]方法修改得来的,不像交互式分割的目标是在得到充分的用户输入后最终预测一个有效掩码,我们的目标是始终为任何提示预测有效掩码,即使提示是有语义歧义的。
This ensures that a pre-trained model is effective in use cases that involve ambiguity, including automatic annotation as required by our data engine §4.
这确保了预训练模型在涉及歧义的用例中是有效的,包括我们的数据引擎§4所需的自动注释。
We note that performing well at this task is challenging and requires specialized modeling and training loss choices, which we discuss in §3.
我们注意到,在这项任务中表现良好具有挑战性,需要专门的模型和训练损失选择,我们将在第 3 节中讨论。
Zero-shot transfer.零样本迁移
Intuitively, our pre-training task endows the model with the ability to respond appropriately to any prompt at inference time, and thus downstream tasks can be solved by engineering appropriate prompts.
直观地说,我们的预训练任务赋予模型在推理时对任何提示做出适当响应的能力,因此可以通过设计适当的提示来解决下游任务。
For example, if one has a bounding box detector for cats, cat instance segmentation can be solved by providing the detector’s box output as a prompt to our model.
例如,如果已经有了一个猫猫边界框检测器,则可以将上述检测器输出的box框作为SAM的prompt输入,来解决猫实体分割子问题。
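A minimal sketch of this composition using the released segment_anything package: the box output of an existing detector is fed to SAM as a prompt. The checkpoint path and the detect_cats detector are placeholders, not part of SAM.
import numpy as np
from segment_anything import sam_model_registry, SamPredictor

# Load a SAM checkpoint (path is a placeholder) and wrap it in a predictor.
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h.pth")
predictor = SamPredictor(sam)

def segment_cats(image: np.ndarray, detect_cats) -> list:
    """Turn the boxes of an existing cat detector into instance masks.
    detect_cats(image) is a hypothetical detector returning (N, 4) XYXY boxes."""
    predictor.set_image(image)          # the image embedding is computed once here
    masks = []
    for box in detect_cats(image):
        m, scores, _ = predictor.predict(box=box, multimask_output=False)
        masks.append(m[0])              # one (H, W) boolean mask per detected box
    return masks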
In general, a wide array of practical segmentation tasks can be cast as prompting.
一般来说,大量的实际分割任务都可以转化成一个提示问题。
In addition to automatic dataset labeling, we explore five diverse example tasks in our experiments in §7.
除了自动数据集标记之外,我们还在第 7 节的实验中探索了五个不同的示例任务。
Related tasks.相关任务
Segmentation is a broad field: there’s interactive segmentation [57, 109], edge detection [3], super pixelization [85], object proposal generation [2], foreground segmentation [94], semantic segmentation [90], instance segmentation [66], panoptic segmentation [59], etc.
图片分割是一个广泛的领域: 有交互式分割 [57, 109]、边缘检测 [3]、超像素化 [85]、目标提案生成 [2]、前景分割 [94]、语义分割 [90]、实例分割 [66]、全景分割 [59] 等。
The goal of our promptable segmentation task is to produce a broadly capable model that can adapt to many (though not all) existing and new segmentation tasks via prompt engineering.
我们的可提示分割任务的目标是产生一个具备通用能力的模型,可以通过提示工程适应许多(尽管不是全部)现有的和新的分割任务。
This capability is a form of task generalization [26].
这种能力是一种任务泛化形式[26]。
Note that this is different than previous work on multi-task segmentation systems.
请注意,这与多任务分割系统不同。
In a multi-task system, a single model performs a fixed set of tasks, e.g., joint semantic, instance, and panoptic segmentation [114, 19, 54], but the training and test tasks are the same.
在多任务系统中,单个模型执行一组固定的任务,例如联合语义、实例和全景分割[114,19,54],但训练和测试任务是相同的。
An important distinction in our work is that a model trained for promptable segmentation can perform a new, different task at inference time by acting as a component in a larger system,
我们工作中的一个重要区别是,可提示分割的模型可以在一个更大系统中充当组件,以在推理时执行新的、不同的任务,
e.g., to perform instance segmentation, a promptable segmentation model is combined with an existing object detector.
例如,可提示的分割模型与已有的对象检测器相结合,可联合实现实例分割功能。
Discussion.讨论
Prompting and composition are powerful tools that enable a single model to be used in extensible ways, potentially to accomplish tasks unknown at the time of model design.
提示和组合是强大的工具,使单个模型能够以可扩展的方式使用,有可能完成模型设计时未知的任务。
This approach is analogous to how other foundation models are used, e.g., how CLIP [82] is the text-image alignment component of the DALL·E [83] image generation system.
这种方法类似于其他基础模型的使用方式,例如 CLIP [82] 是 DALL·E [83] 图像生成系统的文本图像对齐组件。
We anticipate that composable system design, powered by techniques such as prompt engineering, will enable a wider variety of applications than systems trained specifically for a fixed set of tasks.
我们预计,由提示工程等技术支持的可组合系统设计将比专门为一组固定任务训练的系统实现更广泛的应用。
It’s also interesting to compare promptable and interactive segmentation through the lens of composition:
通过组合的角度来比较可提示分割和交互式分割也很有趣:
while interactive segmentation models are designed with human users in mind, a model trained for promptable segmentation can also be composed into a larger algorithmic system as we will demonstrate.
虽然交互式分割模型在设计时考虑了人类用户,但针对可提示分割进行训练的模型也可以组成更大的算法系统,正如我们将演示的那样。
3.Segment Anything Model
We next describe the Segment Anything Model (SAM) for promptable segmentation.
SAM has three components, illustrated in Fig. 4: an image encoder, a flexible prompt encoder, and a fast mask decoder.
SAM有三个组件,如图 4 所示:图像编码器、prompt编码器、掩码解码器。
We build on Transformer vision models [14, 33, 20, 62] with specific tradeoffs for (amortized) real-time performance.
我们在ViT [14,33,20,62] 基础上,针对(摊销的)实时性能做了tradeoff。
We describe these components at a high-level here, with details in §A.
我们在这里high-level地描述了这些组件,细节见§A。
Image encoder.图像编码器
Motivated by scalability and powerful pre-training methods, we use an MAE [47] pre-trained Vision Transformer (ViT) [33] minimally adapted to process high resolution inputs [62].
得益于可扩展性和强大的预训练方法,我们使用了一个 MAE [47] 预训练的 ViT [33],仅做了最小的调整以处理高分辨率输入 [62]。
The image encoder runs once per image and can be applied prior to prompting the model.
图像编码器每张图像运行一次,可以在模型提示输入前就跑完。
In general, the image encoder can be any network that outputs a C×H×W image embedding.
通常来说,图像编码器可以是任何输出大小为 C×H×W 的图像embedding模型
Specifically, a ViT-H/16 with 14×14 windowed attention and four equally-spaced global attention blocks, following [62]. The image encoder’s output is a 16× downscaled embedding of the input image. Since our runtime goal is to process each prompt in real-time, we can afford a high number of image encoder FLOPs because they are computed only once per image, not per prompt.
此处用了ViT。具体来说,就是一个 ViT-H/16,具有 14×14 的窗口注意力和四个等间距的全局注意力块,参照 [62]。图像编码器的输出是输入图像下采样 16 倍的embedding。由于我们的运行目标是实时处理每个提示,我们可以承担图像编码器较高的FLOPs开销,因为每张图像只计算一次,而不是每个提示都计算。
Following standard practices (e.g., [40]), we use an input resolution of 1024×1024 obtained by rescaling the image and padding the shorter side. The image embedding is therefore 64×64. To reduce the channel dimension, following [62], we use a 1×1 convolution to get to 256 channels, followed by a 3×3 convolution also with 256 channels. Each convolution is followed by a layer normalization [4].
按照标准做法(例如 [40]),我们使用通过重新缩放图像并填充较短边得到的 1024×1024 的输入分辨率。因此,图像embedding的大小为 64×64。为了减少通道维度,按照 [62],我们先使用一个 1×1 的卷积将通道数降至 256,然后是一个 3×3 的卷积,同样具有 256 个通道。每个卷积后面都跟着一个层归一化 [4]。
# 缩放图像并填充较短边得到的 1024×1024 的输入分辨率
target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
np.array(resize(to_pil_image(image), target_size))
def preprocess(self, x: torch.Tensor) -> torch.Tensor:
"""Normalize pixel values and pad to a square input."""
# Normalize colors
x = (x - self.pixel_mean) / self.pixel_std
# Pad
h, w = x.shape[-2:]
padh = self.image_encoder.img_size - h
padw = self.image_encoder.img_size - w
x = F.pad(x, (0, padw, 0, padh))
return x
# ViT将图片切成16X16大小的patch,每个patch类比语言模型中的token,然后每个patch输入Transformer后会embedding成一个D维的tensor,在ViT-H中是1280维。ViT-H的输出是(1024/16)×(1024/16)×1280=64×64×1280的
(patch_embed): PatchEmbed(
(proj): Conv2d(3, 1280, kernel_size=(16, 16), stride=(16, 16))
)
# 然后1×1卷积处理后变成64×64×256,再跟着3×3卷积最后还是64*64*256,是stride=1且有padding的
(0): Conv2d(1280, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): LayerNorm2d()
(2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(3): LayerNorm2d()
Prompt encoder提示编码器.
We consider two sets of prompts: sparse (points, boxes, text) and dense (masks).
我们考虑两种提示:稀疏(点、框、文本)和密集(掩模)。
Sparse prompts are mapped to 256-dimensional vectorial embeddings as follows.
稀疏提示被映射为256维的向量嵌入,具体如下。
A point is represented as the sum of a positional encoding [95] of the point’s location and one of two learned embeddings that indicate if the point is either in the foreground or background.
一个点被表示为该点位置的位置编码 [95] 与两个学习到的嵌入之一的和,这两个嵌入指示该点是在前景还是背景。
A box is represented by an embedding pair: (1) the positional encoding of its top-left corner summed with a learned embedding representing “top-left corner” and (2) the same structure but using a learned embedding indicating “bottom-right corner”.
一个框被表示为一对嵌入:(1) 其左上角的位置信息编码与一个表示“左上角”的学习嵌入的和,以及 (2) 相同的结构但使用一个表示“右下角”的学习嵌入。
即,框是一个embedding对:< 左上角那个点的位置编码+一个embedding表征”左上角” , 右下角那个点的位置编码+一个embedding表征”右下角” >
Finally, to represent free-form text we use the text encoder from CLIP [82] (any text encoder is possible in general). We focus on geometric prompts for the remainder of this section and discuss text prompts in depth in §D.5.
最后,为了表示自由形式的文本,我们使用来自 CLIP [82] 的文本编码器(一般而言,任何文本编码器都可以)。本节余下部分我们专注于几何提示,并在附录 D.5 中深入讨论文本提示。
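For intuition only, a sketch of turning a free-form text prompt into an embedding with OpenAI's CLIP text encoder; the final 512→256 linear projection is an assumption added here so the shape matches SAM's 256-d prompt tokens, and the paper's actual text-prompt setup is described in §D.5.
import torch
import clip  # pip install git+https://github.com/openai/CLIP.git

model, _ = clip.load("ViT-B/32")
tokens = clip.tokenize(["a photo of a cat"])        # (1, 77) token ids
with torch.no_grad():
    text_feat = model.encode_text(tokens).float()   # (1, 512) for ViT-B/32

# Hypothetical projection into SAM's 256-d prompt embedding space;
# this layer is not part of the released SAM code.
to_prompt_dim = torch.nn.Linear(text_feat.shape[-1], 256)
text_prompt_embedding = to_prompt_dim(text_feat)    # (1, 256)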
【TODO:prompt的embedding也是学出来的,通过啥loss来训,怎么评判prompt的embedding的好坏的???一直不太能理解这个学习的过程】
# 稀疏编码的模型参数
(pe_layer): PositionEmbeddingRandom()
(point_embeddings): ModuleList(
(0-3): 4 x Embedding(1, 256)
)
(not_a_point_embed): Embedding(1, 256)
# 点的编码,比如一个点的横纵坐标位置为(x,y),经过如下计算,x和y会分别得到两个256维的向量
def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
"""Positionally encode points that are normalized to [0,1]."""
# assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
coords = 2 * coords - 1
coords = coords @ self.positional_encoding_gaussian_matrix
coords = 2 * np.pi * coords
# outputs d_1 x ... x d_n x C shape
return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
# 然后再加上
points = points + 0.5 # Shift to center of pixel
if pad:
padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
points = torch.cat([points, padding_point], dim=1)
labels = torch.cat([labels, padding_label], dim=1)
point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
point_embedding[labels == -1] = 0.0
point_embedding[labels == -1] += self.not_a_point_embed.weight
point_embedding[labels == 0] += self.point_embeddings[0].weight
point_embedding[labels == 1] += self.point_embeddings[1].weight
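To mirror the point-encoding snippet above, here is a sketch of how a box prompt becomes an embedding pair: the positional encoding of each corner plus the learned “top-left” / “bottom-right” embeddings. It assumes the last two entries of point_embeddings hold the corner embeddings (the first two are used for background/foreground points above).
import torch

# Sketch: each box (x1, y1, x2, y2) becomes two 256-d tokens, one per corner.
def embed_box(prompt_encoder, boxes: torch.Tensor) -> torch.Tensor:
    boxes = boxes + 0.5                                  # shift to pixel centers
    corners = boxes.reshape(-1, 2, 2)                    # (B, 2, 2): top-left, bottom-right
    corner_embedding = prompt_encoder.pe_layer.forward_with_coords(
        corners, prompt_encoder.input_image_size
    )
    corner_embedding[:, 0, :] += prompt_encoder.point_embeddings[2].weight  # "top-left"
    corner_embedding[:, 1, :] += prompt_encoder.point_embeddings[3].weight  # "bottom-right"
    return corner_embedding                              # (B, 2, 256)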
Dense prompts (i.e., masks) are embedded using convolutions and summed element-wise with the image embedding.
密集提示(即掩码)使用卷积来做embedding,并按元素sum上了图像embedding。
Dense prompts (i.e., masks) have a spatial correspondence with the image.
密集提示(即掩码)与图像具有空间对应关系。
We input masks at a 4× lower resolution than the input image, then downscale an additional 4× using two 2×2, stride-2 convolutions with output channels 4 and 16, respectively. A final 1×1 convolution maps the channel dimension to 256.
我们以比输入图像低 4× 的分辨率输入掩码,然后使用两个大小为 2×2、步幅为 2 的卷积层,输出通道数分别为 4 和 16,进一步下采样 4×。最后,一个 1×1 卷积将通道维度映射到 256。
Each layer is separated by GELU activations [50] and layer normalization. The mask and image embedding are then added element-wise.
每一层之间都由 GELU 激活函数 [50] 和层归一化分隔。然后,掩码与图像嵌入逐元素相加。
If there is no mask prompt, a learned embedding representing “no mask” is added to each image embedding location.
如果没有掩码提示,则在每个图像嵌入位置添加一个表示“无掩码”的学习嵌入。
(mask_downscaling): Sequential(
(0): Conv2d(1, 4, kernel_size=(2, 2), stride=(2, 2))
(1): LayerNorm2d()
(2): GELU(approximate='none')
(3): Conv2d(4, 16, kernel_size=(2, 2), stride=(2, 2))
(4): LayerNorm2d()
(5): GELU(approximate='none')
(6): Conv2d(16, 256, kernel_size=(1, 1), stride=(1, 1))
)
(no_mask_embed): Embedding(1, 256)
mask_downscaling = nn.Sequential(
nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
LayerNorm2d(mask_in_chans // 4),
activation(),
nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
LayerNorm2d(mask_in_chans),
activation(),
nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
)
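A short sketch of the dense branch as a whole, assuming a 64×64 image embedding grid: a provided low-resolution mask goes through mask_downscaling, otherwise the learned no_mask_embed is broadcast over every spatial location.
from typing import Optional
import torch

def dense_prompt_embedding(prompt_encoder, masks: Optional[torch.Tensor], bs: int) -> torch.Tensor:
    if masks is not None:
        return prompt_encoder.mask_downscaling(masks)    # (bs, 256, 64, 64)
    # No mask prompt: broadcast the learned "no mask" embedding over the grid.
    return prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(bs, -1, 64, 64)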
Mask decoder.掩码解码器
To combine these inputs, we take inspiration from Transformer segmentation models [14, 20] and modify a standard Transformer decoder [103].
为了组合这些输入,我们从 Transformer 分割模型 [14, 20] 中获得灵感,修改了一个标准的 Transformer 解码器 [103]。
Before applying our decoder, we first insert into the set of prompt embeddings a learned output token embedding that will be used at the decoder’s output, analogous to the [class] token in [33].
在应用我们的解码器之前,我们首先在提示嵌入集合中插入一个学习到的输出标记嵌入,该嵌入将在解码器的输出中使用,类似于 [33] 中的 [class] 标记。
For simplicity, we refer to these embeddings (not including the image embedding) collectively as “tokens”.
为简单起见,我们将这些embeddings(不包括图像的)统称为“标记”。
Our decoder design is shown in Fig. 14. Each decoder layer performs 4 steps: (1) self-attention on the tokens, (2) cross-attention from tokens (as queries) to the image embedding, (3) a point-wise MLP updates each token, and (4) cross-attention from the image embedding (as queries) to tokens.
每个解码器层都有如下 4 个步骤:
(1) token上的自注意力【prompt之间为啥要看相似度?】
(2) 从token(作为query)到图像embedding的交叉注意力,
(3) 逐点 MLP 更新每个 token,
(4) 从图像embedding(作为query)到token的交叉注意力。
This last step updates the image embedding with prompt information.
最后一步用提示信息更新了图像embedding。
During cross-attention, the image embedding is treated as a set of 64² (4,096) 256-dimensional vectors.
在交叉注意力过程中,图像embedding被视为一组 64×64 = 4096 个 256 维向量。
Each self/cross-attention and MLP has a residual connection [49], layer normalization, and a dropout [93] of 0.1 at training.
每个自/交叉注意力和 MLP 在训练时都有一个残差连接 [49]、层归一化和 0.1 的 dropout [93]。
The next decoder layer takes the updated tokens and the updated image embedding from the previous layer.
下一个解码器层从上一层获取更新的token和更新的图像embedding。
We use a two-layer decoder.
我们使用两层解码器。
To ensure the decoder has access to critical geometric information the positional encodings are added to the image embedding whenever they participate in an attention layer.
为了确保解码器能够访问关键的几何信息,位置编码在每次参与注意力层时都会添加到图像嵌入中。
Additionally, the entire original prompt tokens (including their positional encodings) are readded to the updated tokens whenever they participate in an attention layer.
此外,每当参与注意力层时,整个原始prompt tokens(包括其位置编码)都会重新添加到更新的tokens中。
This allows for a strong dependence on both the prompt token’s geometric location and type.
这使得对prompt tokens的几何位置和类型形成强烈依赖。
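A schematic of one decoder layer's four steps, written against the module names visible in the architecture printout at the end of this note (self_attn, cross_attn_token_to_image, mlp, cross_attn_image_to_token, norm1-4); q_pe / img_pe stand for the re-added prompt-token and image positional encodings, and the exact residual/normalization placement is a simplification of the released TwoWayAttentionBlock, not a verbatim copy.
import torch

# tokens: (B, N_tok, 256) output + prompt tokens; image_emb: (B, 64*64, 256).
def two_way_block(tokens, image_emb, q_pe, img_pe, blk):
    # (1) self-attention among the tokens
    q = tokens + q_pe
    tokens = blk.norm1(tokens + blk.self_attn(q=q, k=q, v=tokens))
    # (2) cross-attention: tokens (queries) attend to the image embedding
    q, k = tokens + q_pe, image_emb + img_pe
    tokens = blk.norm2(tokens + blk.cross_attn_token_to_image(q=q, k=k, v=image_emb))
    # (3) point-wise MLP updates each token
    tokens = blk.norm3(tokens + blk.mlp(tokens))
    # (4) cross-attention: image embedding (queries) attends to the tokens
    q, k = image_emb + img_pe, tokens + q_pe
    image_emb = blk.norm4(image_emb + blk.cross_attn_image_to_token(q=q, k=k, v=tokens))
    return tokens, image_emb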
After running the decoder, we upsample the updated image embedding by 4× with two transposed convolutional layers (now it’s downscaled 4× relative to the input image).
运行解码器后,我们使用两个转置卷积层将更新的图像嵌入上采样 4 倍(现在相对于输入图像下采样了 4 倍)。
Then, the tokens attend once more to the image embedding and we pass the updated output token embedding to a small 3-layer MLP that outputs a vector matching the channel dimension of the upscaled image embedding.
然后,tokens再次对图像embedding做注意力计算,我们将更新后的输出token embedding传递给一个小型的三层 MLP,它的输出向量与上采样后的图像embedding的维度一样。
Finally, we predict a mask with a spatially point-wise product between the upscaled image embedding and the MLP’s output.
最后,我们通过上采样后的图像嵌入和 MLP 输出之间的空间逐点乘积来预测掩码。
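Putting the last few sentences into shapes: a sketch of how the upscaled image embedding and the per-mask token vectors turn into mask logits. The einsum is illustrative; the tensors correspond to the outputs of output_upscaling and output_hypernetworks_mlps in the printout at the end of this note.
import torch

B, C, H, W = 1, 32, 256, 256            # upscaled image embedding: 4x upsampled, 32 channels
num_masks = 4                           # 3 "ambiguity" masks + 1 extra mask token (discussed below)

upscaled = torch.randn(B, C, H, W)      # stand-in for the output of output_upscaling
hyper_in = torch.randn(B, num_masks, C) # one C-dim vector per mask token (output_hypernetworks_mlps)

# Spatially point-wise product: each mask logit map is the dot product between
# its token vector and every spatial location of the upscaled embedding.
masks = torch.einsum("bnc,bchw->bnhw", hyper_in, upscaled)   # (B, num_masks, 256, 256)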
The transformer uses an embedding dimension of 256. The transformer MLP blocks have a large internal dimension of 2048, but the MLP is applied only to the prompt tokens for which there are relatively few (rarely greater than 20).
Transformer 使用了 256 的嵌入维度。Transformer 的 MLP 块具有较大的内部维度 2048,但 MLP 仅应用于 prompt tokens,而 prompt tokens 数量相对较少(很少超过 20 个)。
However, in cross-attention layers where we have a 64×64 image embedding, we reduce the channel dimension of the queries, keys, and values by 2× to 128 for computational efficiency. All attention layers use 8 heads.
然而,在交叉注意力层中,我们有 64×64 个image embedding,为了计算效率,我们将查询、键和值的通道维度减少一半至 128。所有注意力层使用 8 个头。
The transposed convolutions used to upscale the output image embedding are 2×2, stride 2 with output channel dimensions of 64 and 32 and have GELU activations. They are separated by layer normalization.
用于上采样输出图像嵌入的转置卷积为 2×2,步幅为 2,输出通道维度为 64 和 32,且具有 GELU 激活函数。它们之间由层归一化分隔。
Resolving ambiguity.解决歧义
With one output, the model will average multiple valid masks if given an ambiguous prompt.
输出只有一个,如果提示是不明确的,那么模型将对多个有效掩码进行平均。
To address this, we modify the model to predict multiple output masks for a single prompt (see Fig. 3).
我们修改了模型,对单个提示预测多个输出掩码(见图 3)。
We found 3 mask outputs is sufficient to address most common cases (nested masks are often at most three deep: whole, part, and subpart).
我们发现 3 个掩码输出足以解决最常见的情况(嵌套掩码通常最多三层深度:整体、部分和子部分)。
During training, we backprop only the minimum loss [15, 45, 64] over masks.
在训练期间,我们仅反向传播掩模上的最小损失[15,45,64]。
To rank masks, the model predicts a confidence score (i.e., estimated IoU) for each mask.
为了对掩模进行排名,模型预测每个掩模的置信度得分(即估计的IoU)。
As described, a single input prompt may be ambiguous in the sense that it corresponds to multiple valid masks, and the model will learn to average over these masks.
如前所述,单个输入提示可能存在歧义,因为它可能对应多个有效的掩码,模型将会学习对这些掩码进行平均。
We eliminate this problem with a simple modification: instead of predicting a single mask, we use a small number of output tokens and predict multiple masks simultaneously.
我们通过一个简单的修改来消除这个问题:我们不再预测单个掩码,而是使用少量的输出tokens,同时预测多个掩码。
By default we predict three masks, since we observe that three layers (whole, part, and subpart) are often enough to describe nested masks. During training, we compute the loss (described shortly) between the ground truth and each of the predicted masks, but only backpropagate from the lowest loss.
默认情况下,我们预测三个掩码,因为我们观察到三层(整体、部分和子部分)通常足以描述嵌套的掩码。在训练过程中,我们计算每个预测掩码与真实掩码之间的损失(稍后会描述),但只对最小损失进行反向传播。
This is a common technique used for models with multiple outputs [15, 45, 64].
这是一种常见的技术,用于具有多个输出的模型 [15, 45, 64]。
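A minimal sketch of “backpropagate only the lowest loss”: score each of the three candidate masks against the single ground-truth mask and keep the minimum, so gradients only flow through the best-matching prediction. A plain BCE stands in for the actual focal+dice loss described later.
import torch
import torch.nn.functional as F

def min_loss_over_masks(pred_logits: torch.Tensor, gt_mask: torch.Tensor) -> torch.Tensor:
    """pred_logits: (B, 3, H, W) logits for the three candidate masks.
    gt_mask: (B, H, W) binary ground truth."""
    gt = gt_mask.unsqueeze(1).expand_as(pred_logits).float()
    per_mask = F.binary_cross_entropy_with_logits(
        pred_logits, gt, reduction="none"
    ).mean(dim=(2, 3))                   # (B, 3): one loss per candidate mask
    min_loss, _ = per_mask.min(dim=1)    # keep only the best-matching mask's loss
    return min_loss.mean()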
For use in applications, we’d like to rank predicted masks, so we add a small head (operating on an additional output token) that estimates the IoU between each predicted mask and the object it covers.
为了能实际应用,我们希望对预测输出的这些掩码们进行排序,因此我们添加了一个小头(作用于一个额外的输出token),用于估计每个预测掩码与其覆盖对象之间的 IoU(交并比)。
Ambiguity is much rarer with multiple prompts and the three output masks will usually become similar.
在有多个提示的情况下,歧义要少得多,三个输出掩码通常会变得相似。
To minimize computation of degenerate losses at training and ensure the single unambiguous mask receives a regular gradient signal, we only predict a single mask when more than one prompt is given.
为了在训练中减小退化损失的计算开销,以及确保单个明确的掩码能接收到正常的梯度信号,当收到多个提示时,我们只预测一个掩码。
This is accomplished by adding a fourth output token for an additional mask prediction. This fourth mask is never returned for a single prompt and is the only mask returned for multiple prompts.
这是通过为额外的掩码预测添加第四个输出token来实现的。这个第四个掩码在单个提示的情况下永远不会返回,并且在多个提示的情况下是唯一返回的掩码。
Efficiency.效率
The overall model design is largely motivated by efficiency.Given a precomputed image embedding, the prompt encoder and mask decoder run in a web browser, on CPU, in ∼50ms.This runtime performance enables seamless, real-time interactive prompting of our model.
模型的整体设计很大程度上是由效率驱动的。 给定预计算好的图像embedding,提示编码器和掩码解码器在 Web 浏览器中的 CPU 上运行,耗时约 50 毫秒。 这种运行时性能可以为我们的模型提供无缝、实时的交互式提示。
【Focal Loss: 提出于2017年的论文“FOCAL LOSS FOR DENSE OBJECT DETECTION”中,由Facebook AI团队的Tsung-Yi Lin和其他人开发。Focal Loss是为了解决目标检测问题中类别不平衡问题。在训练目标检测模型时,大部分的区域都是背景,正样本(感兴趣的对象)只占很小的一部分,这就导致了严重的类别不平衡问题。Focal Loss通过增大难以分类样本的权重,降低易于分类样本的权重,来聚焦于训练难样本,从而解决这个问题。】
【Dice Loss: Dice Loss通常用于处理医学图像分割等问题,其基于Dice系数,这是一种用于评估图像分割模型性能的指标。Dice系数能衡量预测的分割区域和实际分割区域的空间重叠度。Dice Loss就是以1减去Dice系数得到的,值域为0到1,值越小说明预测的分割区域和实际分割区域的一致性越好。Dice Loss有助于处理分割任务中的类别不平衡问题。】
We train for the promptable segmentation task using a mixture of geometric prompts (for text prompts see §7.5).
我们使用混合几何提示来训练可提示分割任务(有关文本提示,请参见第 7.5 节)。
Following [92, 37], we simulate an interactive setup by randomly sampling prompts in 11 rounds per mask, allowing SAM to integrate seamlessly into our data engine.
参考 [92, 37],我们通过为每个掩码随机采样 11 轮提示来模拟交互式设置,从而使 SAM 能够无缝集成到我们的数据引擎中。
Losses. We supervise mask prediction with a linear combination of focal loss [65] and dice loss [73] in a 20:1 ratio of focal loss to dice loss, following [20, 14]. Unlike [20, 14], we observe that auxiliary deep supervision after each decoder layer is unhelpful. The IoU prediction head is trained with mean-square-error loss between the IoU prediction and the predicted mask’s IoU with the ground truth mask. It is added to the mask loss with a constant scaling factor of 1.0.
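A hedged sketch of this loss: a 20:1 focal-to-dice combination for the mask, plus an MSE term supervising the IoU head with the actual IoU of the thresholded prediction. The sigmoid focal loss with alpha=0.25, gamma=2.0 defaults is an assumption; the paper does not restate those values here.
import torch
import torch.nn.functional as F

def sam_loss(mask_logits, gt_mask, pred_iou, alpha=0.25, gamma=2.0):
    """mask_logits, gt_mask: (B, H, W); pred_iou: (B,) output of the IoU head."""
    p, gt = torch.sigmoid(mask_logits), gt_mask.float()
    # Focal loss (sigmoid form; alpha/gamma defaults are an assumption).
    ce = F.binary_cross_entropy_with_logits(mask_logits, gt, reduction="none")
    p_t = p * gt + (1 - p) * (1 - gt)
    alpha_t = alpha * gt + (1 - alpha) * (1 - gt)
    focal = (alpha_t * (1 - p_t) ** gamma * ce).mean()
    # Dice loss on the soft prediction.
    inter = (p * gt).sum(dim=(1, 2))
    dice = (1 - (2 * inter + 1) / (p.sum(dim=(1, 2)) + gt.sum(dim=(1, 2)) + 1)).mean()
    # IoU head: MSE against the IoU of the thresholded prediction with the ground truth.
    with torch.no_grad():
        pred_bin = (p > 0.5).float()
        inter_b = (pred_bin * gt).sum(dim=(1, 2))
        union_b = ((pred_bin + gt) > 0).float().sum(dim=(1, 2))
        true_iou = inter_b / union_b.clamp(min=1)
    iou_loss = F.mse_loss(pred_iou, true_iou)
    return 20.0 * focal + 1.0 * dice + 1.0 * iou_loss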
Training algorithm. Following recent approaches [92, 37], we simulate an interactive segmentation setup during training. First, with equal probability either a foreground point or bounding box is selected randomly for the target mask. Points are sampled uniformly from the ground truth mask. Boxes are taken as the ground truth mask’s bounding box, with random noise added in each coordinate with standard deviation equal to 10% of the box sidelength, to a maximum of 20 pixels. This noise profile is a reasonable compromise between applications like instance segmentation, which produce a tight box around the target object, and interactive segmentation, where a user may draw a loose box.
After making a prediction from this first prompt, subsequent points are selected uniformly from the error region between the previous mask prediction and the ground truth mask. Each new point is foreground or background if the error region is a false negative or false positive, respectively. We also supply the mask prediction from the previous iteration as an additional prompt to our model. To provide the next iteration with maximal information, we supply the unthresholded mask logits instead of the binarized mask. When multiple masks are returned, the mask passed to the next iteration and used to sample the next point is the one with the highest predicted IoU.
We find diminishing returns after 8 iteratively sampled points (we have tested up to 16). Additionally, to encourage the model to benefit from the supplied mask, we also use two more iterations where no additional points are sampled. One of these iterations is randomly inserted among the 8 iteratively sampled points, and the other is always at the end. This gives 11 total iterations: one sampled initial input prompt, 8 iteratively sampled points, and two iterations where no new external information is supplied to the model so it can learn to refine its own mask predictions. We note that using a relatively large number of iterations is possible because our lightweight mask decoder requires less than 1% of the image encoder’s compute and, therefore, each iteration adds only a small overhead. This is unlike previous interactive methods that perform only one or a few interactive steps per optimizer update [70, 9, 37, 92].
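A sketch of one step of the simulated interaction: pick the next click uniformly from the error region between the previous prediction and the ground truth, labeling it foreground for a false negative and background for a false positive (the uniform-sampling mechanics here are an assumption about the details).
import torch

def sample_next_point(prev_mask: torch.Tensor, gt_mask: torch.Tensor):
    """prev_mask, gt_mask: (H, W) boolean tensors.
    Returns ((x, y), label) with label 1 = foreground, 0 = background."""
    false_neg = gt_mask & ~prev_mask        # missed object pixels -> foreground click
    false_pos = prev_mask & ~gt_mask        # wrongly included pixels -> background click
    error = false_neg | false_pos
    ys, xs = torch.nonzero(error, as_tuple=True)
    if len(xs) == 0:                        # perfect prediction: nothing to sample
        return None
    i = torch.randint(len(xs), (1,)).item() # uniform over the error region
    label = 1 if false_neg[ys[i], xs[i]] else 0
    return (xs[i].item(), ys[i].item()), label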
Training recipe. We use the AdamW [68] optimizer (β1 = 0.9, β2 = 0.999) and a linear learning rate warmup [42] for 250 iterations and a step-wise learning rate decay schedule. The initial learning rate (lr), after warmup, is 8e−4. We train for 90k iterations (∼2 SA-1B epochs) and decrease the lr by a factor of 10 at 60k iterations and again at 86666 iterations. The batch size is 256 images. To regularize SAM, we set weight decay (wd) to 0.1 and apply drop path [53] (dp) with a rate of 0.4. We use a layer-wise learning rate decay [5] (ld) of 0.8. No data augmentation is applied. We initialize SAM from an MAE [47] pre-trained ViT-H. We distribute training across 256 GPUs, due to the large image encoder and 1024×1024 input size. To limit GPU memory usage, we train with up to 64 randomly sampled masks per GPU. Additionally, we find that lightly filtering SA-1B masks to discard any that cover more than 90% of the image qualitatively improves results.
For ablations and other variations on training (e.g., text-to-mask §D.5), we deviate from the default recipe above as follows. When training with data from the first and second data engine stages only, we augment the input with large-scale jitter [40] with a scale range of [0.1, 2.0]. Intuitively, data augmentation may be helpful when training data is more limited. To train ViT-B and ViT-L, we use 180k iterations with batch size 128 distributed across 128 GPUs. We set lr = 8e−4/4e−4, ld = 0.6/0.8, wd = 0.1, and dp = 0.6/0.4 for ViT-B/L, respectively.
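The optimizer settings above, written out as a hedged PyTorch sketch: warmup and step decay are expressed with LambdaLR, the layer-wise learning rate decay is omitted for brevity, and model stands for whichever SAM variant is being trained.
import torch

def make_optimizer_and_scheduler(model: torch.nn.Module):
    # AdamW with the stated hyperparameters (layer-wise lr decay is not shown here).
    opt = torch.optim.AdamW(model.parameters(), lr=8e-4, betas=(0.9, 0.999), weight_decay=0.1)

    def lr_lambda(step: int) -> float:
        if step < 250:          # linear warmup for 250 iterations
            return (step + 1) / 250
        if step < 60_000:       # constant until 60k iterations
            return 1.0
        if step < 86_666:       # decay by 10x at 60k
            return 0.1
        return 0.01             # and again at 86,666 (training ends at 90k)

    sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)
    return opt, sched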
SAM的模型架构
Sam(
(image_encoder): ImageEncoderViT(
(patch_embed): PatchEmbed(
(proj): Conv2d(3, 1280, kernel_size=(16, 16), stride=(16, 16))
)
(blocks): ModuleList(
(0-31): 32 x Block(
(norm1): LayerNorm((1280,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=1280, out_features=3840, bias=True)
(proj): Linear(in_features=1280, out_features=1280, bias=True)
)
(norm2): LayerNorm((1280,), eps=1e-06, elementwise_affine=True)
(mlp): MLPBlock(
(lin1): Linear(in_features=1280, out_features=5120, bias=True)
(lin2): Linear(in_features=5120, out_features=1280, bias=True)
(act): GELU(approximate='none')
)
)
)
(neck): Sequential(
(0): Conv2d(1280, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): LayerNorm2d()
(2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(3): LayerNorm2d()
)
)
(prompt_encoder): PromptEncoder(
(pe_layer): PositionEmbeddingRandom()
(point_embeddings): ModuleList(
(0-3): 4 x Embedding(1, 256)
)
(not_a_point_embed): Embedding(1, 256)
(mask_downscaling): Sequential(
(0): Conv2d(1, 4, kernel_size=(2, 2), stride=(2, 2))
(1): LayerNorm2d()
(2): GELU(approximate='none')
(3): Conv2d(4, 16, kernel_size=(2, 2), stride=(2, 2))
(4): LayerNorm2d()
(5): GELU(approximate='none')
(6): Conv2d(16, 256, kernel_size=(1, 1), stride=(1, 1))
)
(no_mask_embed): Embedding(1, 256)
)
(mask_decoder): MaskDecoder(
(transformer): TwoWayTransformer(
(layers): ModuleList(
(0-1): 2 x TwoWayAttentionBlock(
(self_attn): Attention(
(q_proj): Linear(in_features=256, out_features=256, bias=True)
(k_proj): Linear(in_features=256, out_features=256, bias=True)
(v_proj): Linear(in_features=256, out_features=256, bias=True)
(out_proj): Linear(in_features=256, out_features=256, bias=True)
)
(norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
(cross_attn_token_to_image): Attention(
(q_proj): Linear(in_features=256, out_features=128, bias=True)
(k_proj): Linear(in_features=256, out_features=128, bias=True)
(v_proj): Linear(in_features=256, out_features=128, bias=True)
(out_proj): Linear(in_features=128, out_features=256, bias=True)
)
(norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
(mlp): MLPBlock(
(lin1): Linear(in_features=256, out_features=2048, bias=True)
(lin2): Linear(in_features=2048, out_features=256, bias=True)
(act): ReLU()
)
(norm3): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
(norm4): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
(cross_attn_image_to_token): Attention(
(q_proj): Linear(in_features=256, out_features=128, bias=True)
(k_proj): Linear(in_features=256, out_features=128, bias=True)
(v_proj): Linear(in_features=256, out_features=128, bias=True)
(out_proj): Linear(in_features=128, out_features=256, bias=True)
)
)
)
(final_attn_token_to_image): Attention(
(q_proj): Linear(in_features=256, out_features=128, bias=True)
(k_proj): Linear(in_features=256, out_features=128, bias=True)
(v_proj): Linear(in_features=256, out_features=128, bias=True)
(out_proj): Linear(in_features=128, out_features=256, bias=True)
)
(norm_final_attn): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
)
(iou_token): Embedding(1, 256)
(mask_tokens): Embedding(4, 256)
(output_upscaling): Sequential(
(0): ConvTranspose2d(256, 64, kernel_size=(2, 2), stride=(2, 2))
(1): LayerNorm2d()
(2): GELU(approximate='none')
(3): ConvTranspose2d(64, 32, kernel_size=(2, 2), stride=(2, 2))
(4): GELU(approximate='none')
)
(output_hypernetworks_mlps): ModuleList(
(0-3): 4 x MLP(
(layers): ModuleList(
(0-1): 2 x Linear(in_features=256, out_features=256, bias=True)
(2): Linear(in_features=256, out_features=32, bias=True)
)
)
)
(iou_prediction_head): MLP(
(layers): ModuleList(
(0-1): 2 x Linear(in_features=256, out_features=256, bias=True)
(2): Linear(in_features=256, out_features=4, bias=True)
)
)
)
)