Song search in Rust using OpenAI

2023-01-14

Intro

Has it ever happened to you that you hear a good song in a TV show, in a movie, in a cafe, etc., and wonder what song it is?

These days it's possible to search for it with ChatGPT or Shazam. In this article, however, we'll explore embeddings to search for a song using a few words from its lyrics, artist, album, or title. We'll use Rust + OpenAI to do so.

The complete code and data for this are available in the song-search-rust-openai GitHub repo.

Song lyrics data

We'll use the Song lyrics dataset from Kaggle. It contains a .csv file for each artist, with headers in PascalCase. Each record (row) is represented in Rust as follows:

use serde::Deserialize;

#[derive(Debug, Deserialize)]
#[serde(rename_all = "PascalCase")]
pub struct Song {
    pub artist: String,
    pub title: String,
    pub album: String,
    pub lyric: String,
}

To load all the songs into memory we use the csv crate. It supports serde deserialization to load data into a custom struct:

use std::path::Path;

impl Song {
    /// Read the CSV at the given path into a Vec<Song>
    pub fn get_songs<P: AsRef<Path>>(path: P) -> Result<Vec<Song>> {
        let mut rdr = csv::Reader::from_path(path)?;
        let mut songs = vec![];
        for result in rdr.deserialize() {
            let song: Song = result?;
            songs.push(song);
        }
        Ok(songs)
    }
}

What are embeddings?

An embedding is a vector of floating-point numbers. The distance between two such vectors measures how related they are: the smaller the distance, the more related they are. Among the various distance functions, we'll use cosine similarity (or, as a distance, 1 - similarity), which OpenAI recommends for comparing embeddings.

flowchart LR; id([text]) --> embeddingmodel; embeddingmodel[Embedding model] --> vector[-0.001, 0.023, ..., 0.255];
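
To make the distance idea concrete, here is a minimal, dependency-free sketch of cosine similarity and cosine distance over plain `f32` slices. The function names are ours for illustration and are not part of the project's code; in the article itself the database does this computation.

```rust
/// Cosine similarity of two vectors: dot(a, b) / (|a| * |b|).
/// Returns `None` when the lengths differ, the input is empty,
/// or either vector has zero magnitude (the ratio is undefined).
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> Option<f32> {
    if a.len() != b.len() || a.is_empty() {
        return None;
    }
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return None;
    }
    Some(dot / (norm_a * norm_b))
}

/// Cosine distance = 1 - cosine similarity, so smaller means more related.
pub fn cosine_distance(a: &[f32], b: &[f32]) -> Option<f32> {
    cosine_similarity(a, b).map(|s| 1.0 - s)
}
```

Vectors pointing the same way get a distance close to 0, and orthogonal vectors get a distance of 1, regardless of their magnitudes.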

How will we use embedding for song search?

OpenAI provides models for creating embeddings from text. To create an embedding for a Song, we concatenate its fields into one big text and pass it to the OpenAI API as a single input.

flowchart LR; query([Song \n + artist \n + title \n + album \n + lyrics text]) --> embeddingmodel; embeddingmodel[OpenAI Embedding model] --> vector[0.002, 0.034, ..., 0.298]; vector --> db[(\nStore vector\nin DB)];

Get the song embedding using the async-openai crate:

pub const EMBEDDING_MODEL: &str = "text-embedding-ada-002";

impl Song {
    /// Append artist + title + album + lyrics as input text for
    /// Song embedding
    pub fn embedding_text(&self) -> String {
        [&self.artist, &self.title, &self.album, &self.lyric]
            .map(|s| s.to_owned())
            .join(" ")
            .replace("\n", " ")
            .trim()
            .to_lowercase()
    }

    /// Get embedding from OpenAI for this Song
    pub async fn get_embedding(&self, openai_client: &Client)
      -> Result<CreateEmbeddingResponse> {
        let response = openai_client
            .embeddings()
            .create(
                CreateEmbeddingRequestArgs::default()
                    .input(self.embedding_text())
                    .model(EMBEDDING_MODEL)
                    .build()?,
            )
            .await?;

        info!(
            "Song {} by {} used {} tokens",
            self.title, self.artist, response.usage.total_tokens
        );

        Ok(response)
    }
}
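
To see what text actually gets sent to the model, here is a standalone sketch of the same concatenation step with no external crates, so it can be run in isolation. The struct is a stripped-down stand-in for the real `Song` (no serde derive), and the song used in the note below is made up.

```rust
/// Stand-in for the article's `Song`, without the serde derive.
pub struct Song {
    pub artist: String,
    pub title: String,
    pub album: String,
    pub lyric: String,
}

impl Song {
    /// Join artist, title, album and lyrics with spaces, turn newlines
    /// into spaces, trim the ends and lowercase the result.
    pub fn embedding_text(&self) -> String {
        [
            self.artist.as_str(),
            self.title.as_str(),
            self.album.as_str(),
            self.lyric.as_str(),
        ]
        .join(" ")
        .replace('\n', " ")
        .trim()
        .to_lowercase()
    }
}
```

For a hypothetical song with artist "Some Artist", title "My Song", album "First Album" and lyric "Hello\nWorld \n", this yields `some artist my song first album hello world`.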

The returned embedding is then stored in a database along with all fields of the Song struct. But first, a small detour:

Vector Database

During my non-exhaustive search I found various vector databases and libraries such as Pinecone, Milvus, Weaviate, and Faiss, and there are many more. But none of them seemed to have an out-of-the-box, ready-to-go Rust library, except pgvector, a Postgres extension to store and query vectors! The pgvector project provides a Docker image with the extension already installed, as well as a Rust library 😎

The extension, table, and index queries to set up the DB:

impl Song {
   /// Create extension, table and index
   pub async fn create_db_resources(pg_pool: &PgPool) -> Result<()> {
       // Create Extension
       let extension_query = "CREATE EXTENSION IF NOT EXISTS vector;";
       sqlx::query(extension_query).execute(pg_pool).await?;

       // Create Table
       let table_query = r#"CREATE TABLE IF NOT EXISTS songs
(
   artist text,
   title text,
   album text,
   lyric text,
   embedding vector(1536)
);"#;

       sqlx::query(table_query).execute(pg_pool).await?;

       // Create Index for Cosine Similarity
       let index_query =
           r#"CREATE INDEX IF NOT EXISTS
              songs_idx ON songs
              USING ivfflat (embedding vector_cosine_ops);"#;

       sqlx::query(index_query).execute(pg_pool).await?;

       Ok(())
   }
}

So, back to storing the embedding in the DB along with the Song fields:

impl Song {
    /// Save embedding for this Song in DB
    pub async fn save_embedding(&self, pg_pool: &PgPool,
      pgvector: pgvector::Vector) -> Result<()> {

        sqlx::query(r#"INSERT INTO songs
            (artist, title, album, lyric, embedding)
            VALUES ($1, $2, $3, $4, $5)"#)
            .bind(self.artist.clone())
            .bind(self.title.clone())
            .bind(self.album.clone())
            .bind(self.lyric.clone())
            .bind(pgvector)
            .execute(pg_pool)
            .await?;

        Ok(())
    }
}

For a search query, we ask the OpenAI model to create an embedding for the query text, then find the N nearest embeddings to it in our database. Those are our search results.

flowchart LR; query([query words]) --> embeddingmodel; embeddingmodel[OpenAI Embedding model] --> vector[-0.071, 0.09, ..., 0.212]; vector --> db[(\nGet N-nearest \nmatching vectors\n from DB)];

impl Song {
    /// Search `n` nearest neighbors for given query in DB
    pub async fn query(query: &str, n: i8,
      client: &Client, pg_pool: &PgPool) -> Result<Vec<Song>> {

        let query = query.trim().to_lowercase();

        // Get embedding from OpenAI
        let response = client
            .embeddings()
            .create(
                CreateEmbeddingRequestArgs::default()
                    .input(query)
                    .model(EMBEDDING_MODEL)
                    .build()?,
            )
            .await?;

        let pgvector = pgvector::Vector::from(
            response.data[0].embedding.clone());

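        // Search for nearest neighbors in database.
        // `<=>` is pgvector's cosine distance operator, matching the
        // `vector_cosine_ops` index (`<->` is Euclidean distance).
        Ok(sqlx::query(
            r#"SELECT artist, title, album, lyric
               FROM songs ORDER BY embedding <=> $1 LIMIT $2::int"#,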
        )
        .bind(pgvector)
        .bind(n)
        .fetch_all(pg_pool)
        .await?
        .into_iter()
        .map(|r| Song {
            artist: r.get("artist"),
            title: r.get("title"),
            album: r.get("album"),
            lyric: r.get("lyric"),
        })
        .collect())
    }
}
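
To make the nearest-neighbor step less of a black box, here is a brute-force sketch in plain Rust of what the database query conceptually does: score every stored vector against the query vector and keep the `n` closest. The function names are ours for illustration; the real project leaves this work to Postgres, which can also use the index to avoid scanning every row.

```rust
/// Cosine distance (1 - cosine similarity); `None` if undefined.
fn cosine_distance(a: &[f32], b: &[f32]) -> Option<f32> {
    if a.len() != b.len() || a.is_empty() {
        return None;
    }
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return None;
    }
    Some(1.0 - dot / (norm_a * norm_b))
}

/// Return the indices of the `n` vectors in `stored` closest to `query`,
/// nearest first. Vectors with an undefined distance are skipped.
pub fn nearest_neighbors(query: &[f32], stored: &[Vec<f32>], n: usize) -> Vec<usize> {
    let mut scored: Vec<(usize, f32)> = stored
        .iter()
        .enumerate()
        .filter_map(|(i, v)| cosine_distance(query, v).map(|d| (i, d)))
        .collect();
    // Sort ascending by distance; `total_cmp` gives a total order on f32.
    scored.sort_by(|a, b| a.1.total_cmp(&b.1));
    scored.into_iter().take(n).map(|(i, _)| i).collect()
}
```

This scans every vector on each query, which is fine for a few thousand songs but is exactly the cost an approximate index is meant to reduce on larger tables.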

Search queries

Now that we have embeddings for the whole Kaggle dataset, let's get the 10 nearest neighbors for each query.

Here are some sample queries and their results. Notice how the results are similar in meaning to the original query:

Query: broken heart
Heartache by Justin Bieber / Unreleased Songs
I Heart ? by Taylor Swift / Taylor Swift (Best Buy Exclusive) 
Them Changes (BBC Radio 1 Live Lounge) by Ariana Grande / BBC Radio 1's Live Lounge 2018
Heartbeat by Beyoncé / Unreleased Songs 
Broken-Hearted Girl (Catalyst Remix) by Beyoncé / Above and Beyoncé - Dance Mixes
A Perfectly Good Heart by Taylor Swift / Taylor Swift
Hopelessly Devoted To You by Taylor Swift
Heart-Shaped Box by Post Malone
Don’t Let It Break Your Heart by Coldplay / Mylo Xyloto
This Love by Maroon 5 / Songs About Jane

Query: phone
Payphone (No Rap Edit) by Maroon 5
Telephone (DJ Dan Vocal Remix) by Lady Gaga / Telephone (The Remixes)
Drake’s Voice Mail Box #3 by Drake / Room for Improvement
Telephone by Lady Gaga / The Fame Monster
hotline bling by Billie Eilish / party favor / hotline bling
Video Phone by Beyoncé / I Am... Sasha Fierce
Text You Pictures by Lady Gaga / Unreleased Songs
Telephone (Passion Pit Remix) by Lady Gaga / The Remix
Telephone (Demo/Solo Version) by Lady Gaga / Unreleased Songs
Video Phone (Extended Remix) by Beyoncé / I Am... Sasha Fierce

Query: your shape
Diamonds by Beyoncé
Shape of you -In love with the shape of you by Ed Sheeran
Shape of You by Ed Sheeran / ÷ (Divide)
NOT MY RESPONSIBILITY by Billie Eilish
Shape of You (NOTD Remix) by Ed Sheeran
Shape of You (Berywam Remix) by Ed Sheeran / Covers
Shape of You (Galantis Remix) by Ed Sheeran
Shape of You (Major Lazer Remix) by Ed Sheeran
Shape of You / XO TOUR Llif3 2017 VMAs by Ed Sheeran
Shape of You (Stormzy Remix) by Ed Sheeran

Query: korean pop
Shine (Yunki Theme) by BTS (방탄소년단) / BTS WORLD (Original Soundtrack)
좋아요 Pt. 2 (I Like it Pt. 2) by BTS (방탄소년단)
힙합성애자 (Hip Hop Lover/Phile) by BTS (방탄소년단) / Dark&Wild
Captain (Namjun Theme) by BTS (방탄소년단) / BTS WORLD (Original Soundtrack)
ペップセ (Peppuse) (Silver Spoon) (Japanese Ver.) by BTS (방탄소년단) / YOUTH
INTRO by BTS (방탄소년단) / WAKE UP
In The SOOP by BTS (방탄소년단)
Arirang by BTS (방탄소년단)
Black Swan - Japanese ver. by BTS (방탄소년단) / MAP OF THE SOUL: 7 ~ The Journey ~
팔도강산 Paldogangsan (Satoori Rap) by BTS (방탄소년단) / O!RUL8,2?

To play around and try more queries, check out the song-search-rust-openai GitHub repo.

Conclusion

Being able to quickly build a contextual, relevance-based search in Rust for less than $1.50 of OpenAI usage across the 6027 songs in the dataset is powerful and fun!
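
As a rough sketch of how a cost figure like that can be derived: the embedding code logs a token count per song, so summing those counts and multiplying by the per-token price gives the total. Both helper names below are hypothetical, and the price is a parameter rather than a quoted rate, since it depends on the model and on when you run it.

```rust
/// Sum per-request token counts, e.g. the `total_tokens` value
/// reported for each song's embedding request.
pub fn sum_tokens(per_song: &[u32]) -> u64 {
    per_song.iter().map(|&t| u64::from(t)).sum()
}

/// Estimated cost in USD for embedding `total_tokens` tokens at a given
/// price per 1000 tokens. The price is an input, not a known constant.
pub fn embedding_cost_usd(total_tokens: u64, usd_per_1k_tokens: f64) -> f64 {
    total_tokens as f64 / 1000.0 * usd_per_1k_tokens
}
```

For example, 1500 tokens at a placeholder price of $0.0004 per 1000 tokens works out to about $0.0006.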

Acknowledgements

This project was made possible by the generous work of @ankane, who released support for vectors with a maximum dimension of 2000 in version 0.4.0 of the pgvector project. (The output dimension of the latest OpenAI embedding model is 1536.)