Telegram Group & Telegram Channel
Разбираемся в генеративных моделях: Flow matching

Помните, в прошлый раз мы разбирали DDPM, где нужно было делать 1000 шагов для генерации? А что если я скажу, что можно сделать всё то же самое, но в разы проще и быстрее?

Сегодня поговорим про flow matching в его самой простой форме - linear interpolation. Если DDPM показался вам сложным, то тут вы офигеете насколько всё просто.

В чём основная идея? Вместо того чтобы учить модель убирать шум пошагово (как в DDPM), мы учим её находить прямой путь от шума к картинке. Да-да, просто рисуем линию из точки А в точку Б!

Как это работает:

1. Берём шум и настоящую картинку
2. Случайно выбираем точку между ними (это наше t)
3. Просим модель предсказать в какую сторону двигаться из этой точки

И всё! Вот честно - это весь алгоритм. Смотрите какой простой код для обучения:

def train_step(self, x0):
batch_size = len(x0)
z = torch.randn(batch_size, self.dim).to(self.device)
t = torch.rand(batch_size, 1).to(self.device)
xt = (1 - t) * z + t * x0 # линейная интерполяция между шумом и картинкой
pred_field = self.vector_field(xt, t)

true_field = x0 - z # вот оно - направление от шума к картинке
loss = F.mse_loss(pred_field, true_field)
return loss # возвращаем loss, а не x


А генерация ещё проще - просто идём маленькими шажками в нужном направлении:

def sample(self, batch_size=64, steps=100):
dt = 1.0 / steps
x = torch.randn(batch_size, self.dim).to(self.device)
for i in range(steps):
t = torch.ones(batch_size, 1).to(self.device) * i * dt
v = self.vector_field(x, t)
x = x + dt * v
return x


А теперь самое интересное - то что мы тут делаем, по сути решаем обычный дифур!

Наш vector_field это просто производная dx/dt, а в sample мы используем метод Эйлера для решения этого дифура. И тут открывается целое поле для экспериментов - можно использовать любые солверы: Рунге-Кутту, multistep методы и прочие штуки из мира численных методов.

В общем берите любой солвер из scipy.integrate и вперёд! Некоторые из них позволят ещё сильнее уменьшить количество шагов при генерации.

Главные преимущества по сравнению с DDPM:

- Не нужно возиться с расписаниями шума
- Процесс полностью детерминированный (мы же просто решаем дифур!)
- Генерация работает в разы быстрее
- Код настолько простой, что его можно написать за 5 минут

Я сам офигел когда первый раз это запустил - на многих задачах качество получается сравнимое с DDPM, а кода в три раза меньше.

Единственный небольшой минус - модель иногда бывает менее стабильной при обучении, т.к. нет стохастичности как в DDPM. Но это решается правильным подбором learning rate.

Flow Matching Guide and Code: https://arxiv.org/pdf/2412.06264



group-telegram.com/neural_cell/264
Create:
Last Update:

Разбираемся в генеративных моделях: Flow matching

Помните, в прошлый раз мы разбирали DDPM, где нужно было делать 1000 шагов для генерации? А что если я скажу, что можно сделать всё то же самое, но в разы проще и быстрее?

Сегодня поговорим про flow matching в его самой простой форме - linear interpolation. Если DDPM показался вам сложным, то тут вы офигеете насколько всё просто.

В чём основная идея? Вместо того чтобы учить модель убирать шум пошагово (как в DDPM), мы учим её находить прямой путь от шума к картинке. Да-да, просто рисуем линию из точки А в точку Б!

Как это работает:

1. Берём шум и настоящую картинку
2. Случайно выбираем точку между ними (это наше t)
3. Просим модель предсказать в какую сторону двигаться из этой точки

И всё! Вот честно - это весь алгоритм. Смотрите какой простой код для обучения:

def train_step(self, x0):
batch_size = len(x0)
z = torch.randn(batch_size, self.dim).to(self.device)
t = torch.rand(batch_size, 1).to(self.device)
xt = (1 - t) * z + t * x0 # линейная интерполяция между шумом и картинкой
pred_field = self.vector_field(xt, t)

true_field = x0 - z # вот оно - направление от шума к картинке
loss = F.mse_loss(pred_field, true_field)
return loss # возвращаем loss, а не x


А генерация ещё проще - просто идём маленькими шажками в нужном направлении:

def sample(self, batch_size=64, steps=100):
dt = 1.0 / steps
x = torch.randn(batch_size, self.dim).to(self.device)
for i in range(steps):
t = torch.ones(batch_size, 1).to(self.device) * i * dt
v = self.vector_field(x, t)
x = x + dt * v
return x


А теперь самое интересное - то что мы тут делаем, по сути решаем обычный дифур!

Наш vector_field это просто производная dx/dt, а в sample мы используем метод Эйлера для решения этого дифура. И тут открывается целое поле для экспериментов - можно использовать любые солверы: Рунге-Кутту, multistep методы и прочие штуки из мира численных методов.

В общем берите любой солвер из scipy.integrate и вперёд! Некоторые из них позволят ещё сильнее уменьшить количество шагов при генерации.

Главные преимущества по сравнению с DDPM:

- Не нужно возиться с расписаниями шума
- Процесс полностью детерминированный (мы же просто решаем дифур!)
- Генерация работает в разы быстрее
- Код настолько простой, что его можно написать за 5 минут

Я сам офигел когда первый раз это запустил - на многих задачах качество получается сравнимое с DDPM, а кода в три раза меньше.

Единственный небольшой минус - модель иногда бывает менее стабильной при обучении, т.к. нет стохастичности как в DDPM. Но это решается правильным подбором learning rate.

Flow Matching Guide and Code: https://arxiv.org/pdf/2412.06264

BY the last neural cell


Warning: Undefined variable $i in /var/www/group-telegram/post.php on line 260

Share with your friend now:
group-telegram.com/neural_cell/264

View MORE
Open in Telegram


Telegram | DID YOU KNOW?

Date: |

Unlike Silicon Valley giants such as Facebook and Twitter, which run very public anti-disinformation programs, Brooking said: "Telegram is famously lax or absent in its content moderation policy." Telegram does offer end-to-end encrypted communications through Secret Chats, but this is not the default setting. Standard conversations use the MTProto method, enabling server-client encryption but with them stored on the server for ease-of-access. This makes using Telegram across multiple devices simple, but also means that the regular Telegram chats you’re having with folks are not as secure as you may believe. The regulator said it has been undertaking several campaigns to educate the investors to be vigilant while taking investment decisions based on stock tips. At the start of 2018, the company attempted to launch an Initial Coin Offering (ICO) which would enable it to enable payments (and earn the cash that comes from doing so). The initial signals were promising, especially given Telegram’s user base is already fairly crypto-savvy. It raised an initial tranche of cash – worth more than a billion dollars – to help develop the coin before opening sales to the public. Unfortunately, third-party sales of coins bought in those initial fundraising rounds raised the ire of the SEC, which brought the hammer down on the whole operation. In 2020, officials ordered Telegram to pay a fine of $18.5 million and hand back much of the cash that it had raised. "He has kind of an old-school cyber-libertarian world view where technology is there to set you free," Maréchal said.
from in


Telegram the last neural cell
FROM American