ChatGPTのStreamモードをOkHttp-sse+Gson+Hiltを使って実装
2023-04-14
はじめに
最近ChatGPTが流行ってますね。ブラウザでChatGPTを使うと、レスポンスを一文字ずつ返してくれます。これをAPI経由で実現するためにはどうすればいいか調べてみると、APIにリクエストを送るときにstream
パラメータに対してtrue
を設定することで、実現できることがわかりました。この一文字ずつ返すのはServer Sent Events(SSE)として送られているみたいです。
実装
SSEを扱いやすくするために、okhttp-sseを使用します。今回はリクエスト部分などの紹介は省き、SSEで通信する箇所をメインに紹介します。実装したコードは、こちらのリポジトリにアップしています。
実装イメージとしてはこのようなものを想定しています。通信を完了するとSnackBarを表示するようにもしてみました。

OkHttpをRepository層にDIするためにHiltのProvidesを使います。
@Module
@InstallIn(SingletonComponent::class)
class OkHttpModule {
// TODO OpenAIのトークンを貼る
private val token = ""
@Singleton
@Provides
fun providesRequestBuilder(): Request.Builder {
return Request.Builder()
.url("https://api.openai.com/v1/chat/completions")
.header("Accept", "application/json")
.addHeader("Authorization", "Bearer $token")
}
@Singleton
@Provides
fun providesOkHttpClient(): OkHttpClient {
return OkHttpClient.Builder()
.readTimeout(10, TimeUnit.MINUTES)
.connectTimeout(10, TimeUnit.MINUTES)
.build()
}
}
Repository層
実際にOpenAIに対してデータ取得をするRepository層を作成します。okhttp-sseを使う上で特に重要な箇所は、EventSourceListenerクラスのonEventメソッドです。ここに送られてきた文字列(Json形式のときもあれば、ただの文字列のときもある)が入ってきます。
SSEの通信イベントをSSEEventとして定義し、StateFlow<SSEEvent>経由で他の層へ公開します。SSEEventはEventSourceListenerクラスのメソッドに加えて、初期状態を表すEmptyを追加しました。callbackFlowを使った実装も検討したのですが、StateFlowを使ったほうが個人的には実装しやすかったのでこちらにしました。
sealed interface SSEEvent {
object Empty : SSEEvent
object Open : SSEEvent
data class Event(val response: GPT35TurboResponse) : SSEEvent
data class Failure(val e: Throwable, val response: Response?) : SSEEvent
object Closed : SSEEvent
}
Repository層のコードです。onEventメソッドで送られてきたJsonレスポンスをGsonで変換します。ただし、ChatGPTが生成した文字列がすべて送られると[DONE]
という文字列が送られてくるので、このときはJsonの変換を行いようにします。
createFactory()
のnewEventSource()
を実行したタイミングで通信を開始します。
interface OpenAiRepository {
suspend fun postCompletions(gpt35Turbo: GPT35Turbo)
val state: StateFlow<SSEEvent>
}
@Singleton
class OpenAiRepositoryImpl @Inject constructor(
private val requestBuilder: Request.Builder,
private val client: OkHttpClient,
) : OpenAiRepository {
private val _state = MutableStateFlow<SSEEvent>(SSEEvent.Empty)
override val state = _state.asStateFlow()
private val gson: Gson = Gson()
private var eventSource: EventSource? = null
private val eventSourceListener = object : EventSourceListener() {
override fun onOpen(eventSource: EventSource, response: Response) {
super.onOpen(eventSource, response)
_state.value = SSEEvent.Open
}
override fun onEvent(eventSource: EventSource, id: String?, type: String?, data: String) {
super.onEvent(eventSource, id, type, data)
if (data != "[DONE]") {
val response = gson.fromJson(data, GPT35TurboResponse::class.java)
val message = response.choices[0].delta.content
_state.value = SSEEvent.Event(response)
}
}
override fun onFailure(eventSource: EventSource, t: Throwable?, response: Response?) {
super.onFailure(eventSource, t, response)
if (t != null) {
_state.value = SSEEvent.Failure(t, response)
}
}
override fun onClosed(eventSource: EventSource) {
super.onClosed(eventSource)
_state.value = SSEEvent.Closed
}
}
override suspend fun postCompletions(gpt35Turbo: GPT35Turbo) {
val requestBody = gson.toJson(gpt35Turbo)
val request = requestBuilder
.post(requestBody.toRequestBody("application/json; charset=UTF-8".toMediaTypeOrNull()))
.build()
withContext(Dispatchers.IO) {
eventSource = EventSources.createFactory(client)
.newEventSource(request, eventSourceListener)
}
}
}
Repository層のHiltの設定です。interfaceを作成しているのでBindsを使います。
@Module
@InstallIn(SingletonComponent::class)
abstract class RepositoryModule() {
@Binds
abstract fun bindOpenAiRepository(
openAiRepositoryImpl: OpenAiRepositoryImpl
): OpenAiRepository
}
UseCase層
Repository層で取得したデータを加工するためにUseCaseを実装します。今回はUseCase層でRepositoryから送られてきたイベントに応じて加工後、StateFlow<State>として公開します。
Stateの実装はこのように行いました。ほぼSSEEventと一緒ですが、UseCase層では、Repository層から送られてきた、GPT35TurboResponse内にあるChatGPTから送られてきた文字列を取り出し公開するために、StateのEventをStringにしています。(実装後の反省ですが、細かく分けすぎて逆にわかりにくくなった気がするのでStateを改めて定義する必要はなかったかもです)
sealed interface State {
object Empty: State
object Open : State
data class Event(val response: String) : State
data class Failure(val e: Throwable, val response: Response?) : State
object Closed : State
}
ほぼRepository層をラップしたものですが、SSEEvent.Eventのところで送られてきた文字列を取り出しています。
interface OpenAiUseCase {
suspend fun postCompletions(gpt35Turbo: GPT35Turbo)
fun cancelCompletions()
val state: StateFlow<State>
}
@Singleton
class OpenAiUseCaseImpl @Inject constructor(
private val repository: OpenAiRepository
) : OpenAiUseCase {
private val _state = MutableStateFlow<State>(State.Empty)
override val state = _state.asStateFlow()
override suspend fun postCompletions(gpt35Turbo: GPT35Turbo) {
repository.postCompletions(gpt35Turbo)
repository.state.collect { event ->
when (event) {
is SSEEvent.Empty -> _state.value = State.Empty
is SSEEvent.Open -> _state.value = State.Open
is SSEEvent.Event -> {
val value = event.response.choices.first().delta.content ?: ""
_state.value = State.Event(value)
}
is SSEEvent.Failure -> {
_state.value = State.Failure(event.e, event.response)
}
is SSEEvent.Closed -> _state.value = State.Closed
}
}
}
}
UseCase層のHiltの設定です。UseCase層もinterfaceを作成しているのでBindsを使います。
@Module
@InstallIn(SingletonComponent::class)
abstract class UseCaseModule {
@Binds
abstract fun bindOpenAiUseCase(
openAiUseCaseImpl: OpenAiUseCaseImpl
): OpenAiUseCase
}
ViewModel 層
UseCaseからデータを受取ってView層に渡してあげます。
@HiltViewModel
class MainActivityViewModel @Inject constructor(
private val useCase: OpenAiUseCase
) : ViewModel() {
data class UiState(
val generatedText: String = "",
)
sealed interface UiEvent {
data class ShowSnackBar(val message: String) : UiEvent
object Empty : UiEvent
}
private val _state = MutableStateFlow(UiState())
val state = _state.asStateFlow()
private val _event = Channel<UiEvent>()
val event: Flow<UiEvent> = _event.consumeAsFlow()
init {
viewModelScope.launch {
useCase.state.collect { state ->
when (state) {
is State.Event -> {
_state.update {
UiState(it.generatedText + state.response)
}
}
is State.Closed -> {
_event.send(UiEvent.ShowSnackBar("完了しました"))
}
else -> { }
}
}
}
}
fun start() {
viewModelScope.launch {
_state.update { UiState("") }
val messages = listOf(
Messages(Messages.Role.SYSTEM, "あなたは生粋の関西人です。"),
Messages(Messages.Role.ASSISTANT, "大阪名物について"),
Messages(Messages.Role.USER, "150文字以内かつ関西弁で紹介して下さい。"),
)
val gpt35Turbo = GPT35Turbo(messages = messages)
useCase.postCompletions(gpt35Turbo)
}
}
}
View層
Actvityに直接TextViewを作成しViewBindingで送ります。
@AndroidEntryPoint
class MainActivity : AppCompatActivity() {
private val viewModel: MainActivityViewModel by viewModels()
private lateinit var binding: ActivityMainBinding
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
binding = ActivityMainBinding.inflate(layoutInflater)
val view = binding.root
setContentView(view)
binding.button.setOnClickListener {
viewModel.start()
}
lifecycleScope.launch {
repeatOnLifecycle(Lifecycle.State.STARTED) {
viewModel.state.collect {
binding.textView.text = it.generatedText
}
}
}
lifecycleScope.launch {
repeatOnLifecycle(Lifecycle.State.STARTED) {
viewModel.event.collect { event ->
when (event) {
is MainActivityViewModel.UiEvent.ShowSnackBar -> {
Snackbar.make(view, event.message, Snackbar.LENGTH_LONG).show()
}
else -> {}
}
}
}
}
}
}
最後に
ChatGPTのStream APIを組み合わせて使ってみました。通常のRESTのAPIを使っても良いのですが、ChatGPTからの応答時間がかかるため、実際にアプリの機能として落としこむためには、ユーザへの見せ方を工夫する必要があります。しかし、このStreamを使った実装方法では、リアルタイムに結果をユーザへ反映できるため飽きさせない、かつ、わくわくする見せ方を可能とします。可能であればこの実装方法で機能に組込みたいですね。