Telegram Group & Telegram Channel
Разбираемся в генеративных моделях: Flow matching

Помните, в прошлый раз мы разбирали DDPM, где нужно было делать 1000 шагов для генерации? А что если я скажу, что можно сделать всё то же самое, но в разы проще и быстрее?

Сегодня поговорим про flow matching в его самой простой форме - linear interpolation. Если DDPM показался вам сложным, то тут вы офигеете насколько всё просто.

В чём основная идея? Вместо того чтобы учить модель убирать шум пошагово (как в DDPM), мы учим её находить прямой путь от шума к картинке. Да-да, просто рисуем линию из точки А в точку Б!

Как это работает:

1. Берём шум и настоящую картинку
2. Случайно выбираем точку между ними (это наше t)
3. Просим модель предсказать в какую сторону двигаться из этой точки

И всё! Вот честно - это весь алгоритм. Смотрите какой простой код для обучения:

def train_step(self, x0):
batch_size = len(x0)
z = torch.randn(batch_size, self.dim).to(self.device)
t = torch.rand(batch_size, 1).to(self.device)
xt = (1 - t) * z + t * x0 # линейная интерполяция между шумом и картинкой
pred_field = self.vector_field(xt, t)

true_field = x0 - z # вот оно - направление от шума к картинке
loss = F.mse_loss(pred_field, true_field)
return loss # возвращаем loss, а не x


А генерация ещё проще - просто идём маленькими шажками в нужном направлении:

def sample(self, batch_size=64, steps=100):
dt = 1.0 / steps
x = torch.randn(batch_size, self.dim).to(self.device)
for i in range(steps):
t = torch.ones(batch_size, 1).to(self.device) * i * dt
v = self.vector_field(x, t)
x = x + dt * v
return x


А теперь самое интересное - то что мы тут делаем, по сути решаем обычный дифур!

Наш vector_field это просто производная dx/dt, а в sample мы используем метод Эйлера для решения этого дифура. И тут открывается целое поле для экспериментов - можно использовать любые солверы: Рунге-Кутту, multistep методы и прочие штуки из мира численных методов.

В общем берите любой солвер из scipy.integrate и вперёд! Некоторые из них позволят ещё сильнее уменьшить количество шагов при генерации.

Главные преимущества по сравнению с DDPM:

- Не нужно возиться с расписаниями шума
- Процесс полностью детерминированный (мы же просто решаем дифур!)
- Генерация работает в разы быстрее
- Код настолько простой, что его можно написать за 5 минут

Я сам офигел когда первый раз это запустил - на многих задачах качество получается сравнимое с DDPM, а кода в три раза меньше.

Единственный небольшой минус - модель иногда бывает менее стабильной при обучении, т.к. нет стохастичности как в DDPM. Но это решается правильным подбором learning rate.

Flow Matching Guide and Code: https://arxiv.org/pdf/2412.06264



group-telegram.com/neural_cell/264
Create:
Last Update:

Разбираемся в генеративных моделях: Flow matching

Помните, в прошлый раз мы разбирали DDPM, где нужно было делать 1000 шагов для генерации? А что если я скажу, что можно сделать всё то же самое, но в разы проще и быстрее?

Сегодня поговорим про flow matching в его самой простой форме - linear interpolation. Если DDPM показался вам сложным, то тут вы офигеете насколько всё просто.

В чём основная идея? Вместо того чтобы учить модель убирать шум пошагово (как в DDPM), мы учим её находить прямой путь от шума к картинке. Да-да, просто рисуем линию из точки А в точку Б!

Как это работает:

1. Берём шум и настоящую картинку
2. Случайно выбираем точку между ними (это наше t)
3. Просим модель предсказать в какую сторону двигаться из этой точки

И всё! Вот честно - это весь алгоритм. Смотрите какой простой код для обучения:

def train_step(self, x0):
batch_size = len(x0)
z = torch.randn(batch_size, self.dim).to(self.device)
t = torch.rand(batch_size, 1).to(self.device)
xt = (1 - t) * z + t * x0 # линейная интерполяция между шумом и картинкой
pred_field = self.vector_field(xt, t)

true_field = x0 - z # вот оно - направление от шума к картинке
loss = F.mse_loss(pred_field, true_field)
return loss # возвращаем loss, а не x


А генерация ещё проще - просто идём маленькими шажками в нужном направлении:

def sample(self, batch_size=64, steps=100):
dt = 1.0 / steps
x = torch.randn(batch_size, self.dim).to(self.device)
for i in range(steps):
t = torch.ones(batch_size, 1).to(self.device) * i * dt
v = self.vector_field(x, t)
x = x + dt * v
return x


А теперь самое интересное - то что мы тут делаем, по сути решаем обычный дифур!

Наш vector_field это просто производная dx/dt, а в sample мы используем метод Эйлера для решения этого дифура. И тут открывается целое поле для экспериментов - можно использовать любые солверы: Рунге-Кутту, multistep методы и прочие штуки из мира численных методов.

В общем берите любой солвер из scipy.integrate и вперёд! Некоторые из них позволят ещё сильнее уменьшить количество шагов при генерации.

Главные преимущества по сравнению с DDPM:

- Не нужно возиться с расписаниями шума
- Процесс полностью детерминированный (мы же просто решаем дифур!)
- Генерация работает в разы быстрее
- Код настолько простой, что его можно написать за 5 минут

Я сам офигел когда первый раз это запустил - на многих задачах качество получается сравнимое с DDPM, а кода в три раза меньше.

Единственный небольшой минус - модель иногда бывает менее стабильной при обучении, т.к. нет стохастичности как в DDPM. Но это решается правильным подбором learning rate.

Flow Matching Guide and Code: https://arxiv.org/pdf/2412.06264

BY the last neural cell


Warning: Undefined variable $i in /var/www/group-telegram/post.php on line 260

Share with your friend now:
group-telegram.com/neural_cell/264

View MORE
Open in Telegram


Telegram | DID YOU KNOW?

Date: |

Asked about its stance on disinformation, Telegram spokesperson Remi Vaughn told AFP: "As noted by our CEO, the sheer volume of information being shared on channels makes it extremely difficult to verify, so it's important that users double-check what they read." And indeed, volatility has been a hallmark of the market environment so far in 2022, with the S&P 500 still down more than 10% for the year-to-date after first sliding into a correction last month. The CBOE Volatility Index, or VIX, has held at a lofty level of more than 30. After fleeing Russia, the brothers founded Telegram as a way to communicate outside the Kremlin's orbit. They now run it from Dubai, and Pavel Durov says it has more than 500 million monthly active users. But because group chats and the channel features are not end-to-end encrypted, Galperin said user privacy is potentially under threat. Sebi said data, emails and other documents are being retrieved from the seized devices and detailed investigation is in progress.
from br


Telegram the last neural cell
FROM American