ChatGPTのStreamモードをOkHttp-sse+Gson+Hiltを使って実装

2023-04-14

はじめに

最近ChatGPTが流行ってますね。ブラウザでChatGPTを使うと、レスポンスを一文字ずつ返してくれます。これをAPI経由で実現するためにはどうすればいいか調べてみると、APIにリクエストを送るときにstreamパラメータに対してtrueを設定することで、実現できることがわかりました。この一文字ずつ返すのはServer Sent Events(SSE)として送られているみたいです。

実装

SSEを扱いやすくするために、okhttp-sseを使用します。今回はリクエスト部分などの紹介は省き、SSEで通信する箇所をメインに紹介します。実装したコードは、こちらのリポジトリにアップしています。

実装イメージとしてはこのようなものを想定しています。通信を完了するとSnackBarを表示するようにもしてみました。

完成のイメージ

OkHttpをRepository層にDIするためにHiltのProvidesを使います。

OkHttpModule.kt
@Module
@InstallIn(SingletonComponent::class)
class OkHttpModule {

    // TODO OpenAIのトークンを貼る
    private val token = ""

    @Singleton
    @Provides
    fun providesRequestBuilder(): Request.Builder {
        return Request.Builder()
            .url("https://api.openai.com/v1/chat/completions")
            .header("Accept", "application/json")
            .addHeader("Authorization", "Bearer $token")
    }

    @Singleton
    @Provides
    fun providesOkHttpClient(): OkHttpClient {
        return OkHttpClient.Builder()
            .readTimeout(10, TimeUnit.MINUTES)
            .connectTimeout(10, TimeUnit.MINUTES)
            .build()
    }
}

Repository層

実際にOpenAIに対してデータ取得をするRepository層を作成します。okhttp-sseを使う上で特に重要な箇所は、EventSourceListenerクラスのonEventメソッドです。ここに送られてきた文字列(Json形式のときもあれば、ただの文字列のときもある)が入ってきます。

SSEの通信イベントをSSEEventとして定義し、StateFlow<SSEEvent>経由で他の層へ公開します。SSEEventはEventSourceListenerクラスのメソッドに加えて、初期状態を表すEmptyを追加しました。callbackFlowを使った実装も検討したのですが、StateFlowを使ったほうが個人的には実装しやすかったのでこちらにしました。

SSEEvent.kt
sealed interface SSEEvent {
    object Empty : SSEEvent
    object Open : SSEEvent
    data class Event(val response: GPT35TurboResponse) : SSEEvent
    data class Failure(val e: Throwable, val response: Response?) : SSEEvent
    object Closed : SSEEvent
}

Repository層のコードです。onEventメソッドで送られてきたJsonレスポンスをGsonで変換します。ただし、ChatGPTが生成した文字列がすべて送られると[DONE]という文字列が送られてくるので、このときはJsonの変換を行いようにします。 createFactory()newEventSource()を実行したタイミングで通信を開始します。

OpenAiRepository.kt
interface OpenAiRepository {
    suspend fun postCompletions(gpt35Turbo: GPT35Turbo)
    val state: StateFlow<SSEEvent>
}

@Singleton
class OpenAiRepositoryImpl @Inject constructor(
    private val requestBuilder: Request.Builder,
    private val client: OkHttpClient,
) : OpenAiRepository {

    private val _state = MutableStateFlow<SSEEvent>(SSEEvent.Empty)
    override val state = _state.asStateFlow()

    private val gson: Gson = Gson()
    private var eventSource: EventSource? = null
    private val eventSourceListener = object : EventSourceListener() {

        override fun onOpen(eventSource: EventSource, response: Response) {
            super.onOpen(eventSource, response)
            _state.value = SSEEvent.Open
        }

        override fun onEvent(eventSource: EventSource, id: String?, type: String?, data: String) {
            super.onEvent(eventSource, id, type, data)

            if (data != "[DONE]") {
                val response = gson.fromJson(data, GPT35TurboResponse::class.java)
                val message = response.choices[0].delta.content
                _state.value = SSEEvent.Event(response)
            }
        }

        override fun onFailure(eventSource: EventSource, t: Throwable?, response: Response?) {
            super.onFailure(eventSource, t, response)

            if (t != null) {
                _state.value = SSEEvent.Failure(t, response)
            }
        }

        override fun onClosed(eventSource: EventSource) {
            super.onClosed(eventSource)

            _state.value = SSEEvent.Closed
        }
    }

    override suspend fun postCompletions(gpt35Turbo: GPT35Turbo) {
        val requestBody = gson.toJson(gpt35Turbo)

        val request = requestBuilder
            .post(requestBody.toRequestBody("application/json; charset=UTF-8".toMediaTypeOrNull()))
            .build()

        withContext(Dispatchers.IO) {
            eventSource = EventSources.createFactory(client)
                .newEventSource(request, eventSourceListener)
        }
    }
}

Repository層のHiltの設定です。interfaceを作成しているのでBindsを使います。

RepositoryModule.kt
@Module
@InstallIn(SingletonComponent::class)
abstract class RepositoryModule() {

    @Binds
    abstract fun bindOpenAiRepository(
        openAiRepositoryImpl: OpenAiRepositoryImpl
    ): OpenAiRepository
}

UseCase層

Repository層で取得したデータを加工するためにUseCaseを実装します。今回はUseCase層でRepositoryから送られてきたイベントに応じて加工後、StateFlow<State>として公開します。

Stateの実装はこのように行いました。ほぼSSEEventと一緒ですが、UseCase層では、Repository層から送られてきた、GPT35TurboResponse内にあるChatGPTから送られてきた文字列を取り出し公開するために、StateのEventをStringにしています。(実装後の反省ですが、細かく分けすぎて逆にわかりにくくなった気がするのでStateを改めて定義する必要はなかったかもです)

State.kt
sealed interface State {
    object Empty: State
    object Open : State
    data class Event(val response: String) : State
    data class Failure(val e: Throwable, val response: Response?) : State
    object Closed : State
}

ほぼRepository層をラップしたものですが、SSEEvent.Eventのところで送られてきた文字列を取り出しています。

OpenAiUseCase.kt
interface OpenAiUseCase {
    suspend fun postCompletions(gpt35Turbo: GPT35Turbo)
    fun cancelCompletions()
    val state: StateFlow<State>
}

@Singleton
class OpenAiUseCaseImpl @Inject constructor(
    private val repository: OpenAiRepository
) : OpenAiUseCase {

    private val _state = MutableStateFlow<State>(State.Empty)
    override val state = _state.asStateFlow()

    override suspend fun postCompletions(gpt35Turbo: GPT35Turbo) {
        repository.postCompletions(gpt35Turbo)
        repository.state.collect { event ->

            when (event) {
                is SSEEvent.Empty -> _state.value = State.Empty
                is SSEEvent.Open -> _state.value = State.Open
                is SSEEvent.Event -> {
                    val value = event.response.choices.first().delta.content ?: ""
                    _state.value = State.Event(value)
                }
                is SSEEvent.Failure -> {
                    _state.value = State.Failure(event.e, event.response)
                }
                is SSEEvent.Closed -> _state.value = State.Closed
            }
        }
    }
}

UseCase層のHiltの設定です。UseCase層もinterfaceを作成しているのでBindsを使います。

UseCaseModule.kt
@Module
@InstallIn(SingletonComponent::class)
abstract class UseCaseModule {

    @Binds
    abstract fun bindOpenAiUseCase(
        openAiUseCaseImpl: OpenAiUseCaseImpl
    ): OpenAiUseCase
}

ViewModel層

UseCaseからデータを受取ってView層に渡してあげます。

MainActivityViewModel.kt
@HiltViewModel
class MainActivityViewModel @Inject constructor(
    private val useCase: OpenAiUseCase
) : ViewModel() {

    data class UiState(
        val generatedText: String = "",
    )

    sealed interface UiEvent {
        data class ShowSnackBar(val message: String) : UiEvent
        object Empty : UiEvent
    }

    private val _state = MutableStateFlow(UiState())
    val state = _state.asStateFlow()
    private val _event = Channel<UiEvent>()
    val event: Flow<UiEvent> = _event.consumeAsFlow()

    init {
        viewModelScope.launch {
            useCase.state.collect { state ->
                when (state) {
                    is State.Event -> {
                        _state.update {
                            UiState(it.generatedText + state.response)
                        }
                    }
                    is State.Closed -> {
                        _event.send(UiEvent.ShowSnackBar("完了しました"))
                    }
                    else -> { }
                }
            }
        }
    }

    fun start() {
        viewModelScope.launch {
            _state.update { UiState("") }
            val messages = listOf(
                Messages(Messages.Role.SYSTEM, "あなたは生粋の関西人です。"),
                Messages(Messages.Role.ASSISTANT, "大阪名物について"),
                Messages(Messages.Role.USER, "150文字以内かつ関西弁で紹介して下さい。"),
            )
            val gpt35Turbo = GPT35Turbo(messages = messages)
            useCase.postCompletions(gpt35Turbo)
        }
    }
}

View層

Actvityに直接TextViewを作成しViewBindingで送ります。

OpenAiRepository.kt
@AndroidEntryPoint
class MainActivity : AppCompatActivity() {

    private val viewModel: MainActivityViewModel by viewModels()
    private lateinit var binding: ActivityMainBinding

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)

        binding = ActivityMainBinding.inflate(layoutInflater)
        val view = binding.root
        setContentView(view)

        binding.button.setOnClickListener {
            viewModel.start()
        }

        lifecycleScope.launch {
            repeatOnLifecycle(Lifecycle.State.STARTED) {
                viewModel.state.collect {
                    binding.textView.text = it.generatedText
                }
            }
        }
        lifecycleScope.launch {
            repeatOnLifecycle(Lifecycle.State.STARTED) {
                viewModel.event.collect { event ->
                    when (event) {
                        is MainActivityViewModel.UiEvent.ShowSnackBar -> {
                            Snackbar.make(view, event.message, Snackbar.LENGTH_LONG).show()
                        }
                        else -> {}
                    }
                }
            }
        }
    }
}

最後に

ChatGPTのStream APIを組み合わせて使ってみました。通常のRESTのAPIを使っても良いのですが、ChatGPTからの応答時間がかかるため、実際にアプリの機能として落としこむためには、ユーザへの見せ方を工夫する必要があります。しかし、このStreamを使った実装方法では、リアルタイムに結果をユーザへ反映できるため飽きさせない、かつ、わくわくする見せ方を可能とします。可能であればこの実装方法で機能に組込みたいですね。

Tatsumi0000

Written by Tatsumi0000 モバイル開発が好きなエンジニアのブログです. GitHub

Copyright © 2023, Tatsumi0000 All Rights Reserved.