From 645a987cf15db30e815291722865030611ed623a Mon Sep 17 00:00:00 2001 From: Himadri Bhattacharjee <107522312+lavafroth@users.noreply.github.com> Date: Sun, 29 Jun 2025 19:58:30 +0530 Subject: [PATCH 1/6] feat: scaffolding for lsp very sharp edges --- .helix/languages.toml | 6 ++ Cargo.lock | 142 +++++++++++++++++++++++++- Cargo.toml | 4 +- src/main.rs | 224 +++++++++++++++++++++++++++++++++++++++--- src/v2/api.rs | 2 +- 5 files changed, 361 insertions(+), 17 deletions(-) create mode 100644 .helix/languages.toml diff --git a/.helix/languages.toml b/.helix/languages.toml new file mode 100644 index 0000000..c7933bf --- /dev/null +++ b/.helix/languages.toml @@ -0,0 +1,6 @@ +[language-server.silos] +command = "./target/debug/silos" + +[[language]] +name = "go" +language-servers = [ { name = "silos" } ] diff --git a/Cargo.lock b/Cargo.lock index d25f47d..96dd42c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -289,12 +289,34 @@ dependencies = [ "derive_arbitrary", ] +[[package]] +name = "async-trait" +version = "0.1.88" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e539d3fca749fcee5236ab05e93a52867dd549cc157c8cb7f99595f3cedffdb5" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "atomic-waker" version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" +[[package]] +name = "auto_impl" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ffdcb70bdbc4d478427380519163274ac86e52916e10f0a8889adf0f96d3fee7" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "autocfg" version = "1.4.0" @@ -694,6 +716,19 @@ dependencies = [ "syn", ] +[[package]] +name = "dashmap" +version = "5.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856" +dependencies = [ + "cfg-if", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "deranged" version = "0.4.0" @@ -1377,6 +1412,12 @@ dependencies = [ "rand_distr", ] +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" + [[package]] name = "hashbrown" version = "0.15.3" @@ -1695,7 +1736,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cea70ddb795996207ad57735b50c5982d8844f38ba9ee5f1aedcfb708a2aa11e" dependencies = [ "equivalent", - "hashbrown", + "hashbrown 0.15.3", ] [[package]] @@ -1878,6 +1919,19 @@ version = "0.4.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94" +[[package]] +name = "lsp-types" +version = "0.94.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c66bfd44a06ae10647fe3f8214762e9369fd4248df1350924b4ef9e770a85ea1" +dependencies = [ + "bitflags 1.3.2", + "serde", + "serde_json", + "serde_repr", + "url", +] + [[package]] name = "macro_rules_attribute" version = "0.2.2" @@ -2275,6 +2329,26 @@ version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" +[[package]] +name = "pin-project" +version = "1.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "677f1add503faace112b9f1373e43e9e054bfdd22ff1a63c1bc485eaec6a6a8a" +dependencies = [ + "pin-project-internal", +] + +[[package]] +name = "pin-project-internal" +version = "1.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e918e4ff8c4549eb882f14b3a4bc8c8bc93de829416eacf579f1207a8fbf861" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "pin-project-lite" version = "0.2.16" @@ -2596,7 +2670,7 @@ dependencies = [ "tokio", "tokio-native-tls", "tokio-util", - "tower", + "tower 0.5.2", "tower-http", "tower-service", "url", @@ -2791,6 +2865,17 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_repr" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "175ee3e80ae9982737ca543e96133087cbd9a485eecc3bc4de9c1a37b47ea59c" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "serde_urlencoded" version = "0.7.1" @@ -2856,6 +2941,8 @@ dependencies = [ "serde", "serde_json", "tokenizers", + "tokio", + "tower-lsp", "tracing", "tracing-subscriber", "tree-sitter", @@ -3230,6 +3317,20 @@ dependencies = [ "winnow 0.7.10", ] +[[package]] +name = "tower" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" +dependencies = [ + "futures-core", + "futures-util", + "pin-project", + "pin-project-lite", + "tower-layer", + "tower-service", +] + [[package]] name = "tower" version = "0.5.2" @@ -3258,7 +3359,7 @@ dependencies = [ "http-body", "iri-string", "pin-project-lite", - "tower", + "tower 0.5.2", "tower-layer", "tower-service", ] @@ -3269,6 +3370,40 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" +[[package]] +name = "tower-lsp" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4ba052b54a6627628d9b3c34c176e7eda8359b7da9acd497b9f20998d118508" +dependencies = [ + "async-trait", + "auto_impl", + "bytes", + "dashmap", + "futures", + "httparse", + "lsp-types", + "memchr", + "serde", + "serde_json", + "tokio", + "tokio-util", + "tower 0.4.13", + "tower-lsp-macros", + "tracing", +] + +[[package]] +name = "tower-lsp-macros" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "84fd902d4e0b9a4b27f2f440108dc034e1758628a9b702f8ec61ad66355422fa" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "tower-service" version = "0.3.3" @@ -3496,6 +3631,7 @@ dependencies = [ "form_urlencoded", "idna", "percent-encoding", + "serde", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index fa4c799..8cd3a24 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "silos" -version = "1.0.0" +version = "1.1.0" edition = "2024" [dependencies] @@ -24,3 +24,5 @@ tree-sitter = "0.25.6" tree-sitter-c = "0.24.1" tree-sitter-go = "0.23.4" tree-sitter-rust = "0.24.0" +tokio = { version = "1.45.1", features = ["io-std", "macros", "rt", "rt-multi-thread"] } +tower-lsp = "0.20.0" diff --git a/src/main.rs b/src/main.rs index 1888865..f5b0293 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,11 +1,19 @@ +use std::sync::Arc; + +use actix_web::web::Data; use actix_web::{App, HttpServer, web}; use anyhow::{Context, Error as E, Result, bail}; use clap::Parser; use hora::core::ann_index::ANNIndex; use hora::index::hnsw_idx::HNSWIndex; use kdl::KdlDocument; -use state::State; +use state::{State, StateWrapper}; use std::collections::HashMap; +use tokio::sync::Mutex; +use tower_lsp::lsp_types::*; +use tower_lsp::{Client, LanguageServer, LspService, Server}; +use tracing::{error, info}; +use tree_sitter::Parser as TSParser; mod embed; mod state; @@ -15,6 +23,8 @@ mod v2; #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] struct Args { + mode: Option, + /// Run on the Nth GPU device. #[arg(long)] gpu: Option, @@ -63,6 +73,7 @@ async fn main() -> Result<()> { tracing_subscriber::fmt::init(); let args = Args::parse(); let port = args.port; + let mode = args.mode.clone(); let mut embed = embed::Embed::new(args)?; let mut dict = HashMap::default(); @@ -139,15 +150,204 @@ async fn main() -> Result<()> { let appstate_wrapped = web::Data::new(appstate.build()); - HttpServer::new(move || { - App::new() - .app_data(appstate_wrapped.clone()) - .service(v1::api::get_snippet) - .service(v1::api::add_snippet) - .service(v2::api::get_snippet) - }) - .bind(("127.0.0.1", port))? - .run() - .await - .map_err(E::from) + if mode.is_some_and(|v| v == "http") { + HttpServer::new(move || { + App::new() + .app_data(appstate_wrapped.clone()) + .service(v1::api::get_snippet) + .service(v1::api::add_snippet) + .service(v2::api::get_snippet) + }) + .bind(("127.0.0.1", port))? + .run() + .await + .map_err(E::from) + } else { + let stdin = tokio::io::stdin(); + let stdout = tokio::io::stdout(); + + let (service, socket) = LspService::new(|client| Backend { + client, + body: Arc::new(Mutex::new(String::default())), + appstate: appstate_wrapped.clone(), + }); + Server::new(stdin, stdout, socket).serve(service).await; + Ok(()) + } +} + +struct Backend { + client: Client, + body: Arc>, + appstate: Data, +} + +pub fn string_range_index(s: &str, r: Range) -> &str { + let mut newline_count = 0; + let mut start = None; + let mut end = None; + for (i, c) in s.chars().enumerate() { + if newline_count == r.start.line && start.is_none() { + start.replace(i + r.start.character as usize); + } + + if newline_count == r.end.line && end.is_none() { + end.replace(i + r.end.character as usize); + } + if c == '\n' { + newline_count += 1; + } + } + &s[start.unwrap_or_default()..end.unwrap_or(s.len())] +} + +#[tower_lsp::async_trait] +impl LanguageServer for Backend { + async fn initialize( + &self, + _: InitializeParams, + ) -> tower_lsp::jsonrpc::Result { + Ok(InitializeResult { + capabilities: ServerCapabilities { + text_document_sync: Some(TextDocumentSyncCapability::Kind( + TextDocumentSyncKind::FULL, + )), + code_action_provider: Some( + tower_lsp::lsp_types::CodeActionProviderCapability::Options( + CodeActionOptions::default(), + ), + ), + ..Default::default() + }, + ..Default::default() + }) + } + + async fn initialized(&self, _: InitializedParams) { + self.client + .log_message(MessageType::INFO, "server initialized!") + .await; + } + + async fn shutdown(&self) -> tower_lsp::jsonrpc::Result<()> { + Ok(()) + } + + async fn did_open(&self, params: DidOpenTextDocumentParams) { + // TODO: build an index for multiple documents in workdir + *self.body.lock().await = params.text_document.text; + } + + async fn did_change(&self, params: DidChangeTextDocumentParams) { + if let Some(body) = params.content_changes.into_iter().next() { + *self.body.lock().await = body.text; + } + } + + async fn code_action( + &self, + params: CodeActionParams, + ) -> tower_lsp::jsonrpc::Result> { + let uri = params.text_document.uri; + let body = self.body.lock().await.to_string(); + + let range = params.range; + let new_text = string_range_index(&body, range); + let Some((_before, after)) = new_text.split_once("silos: ") else { + return Ok(None); + }; + let Some((desc, _after)) = after.split_once("\n") else { + return Ok(None); + }; + + + let Some((prompt, lang)) = desc.rsplit_once(" in ") else { + error!("{}", v2::errors::Error::MissingSuffix); + return Ok(None); + }; + + let langfn = match v2::api::get_lang(lang) { + Ok(o) => o, + Err(e) => { + error!("{e}"); + return Ok(None); + } + }; + + info!(prompt = prompt, language = lang, "v2 request"); + + let mut appstate = self + .appstate + .inner + .lock() + .map_err(|_| v2::errors::Error::Busy) + .expect("booo"); + let target = appstate + .embed + .embed(prompt) + .map_err(|_| v2::errors::Error::EmbedFailed) + .expect("booo"); + let mut parser = TSParser::new(); + parser + .set_language(&langfn) + .map_err(|_| v2::errors::Error::UnknownLang) + .expect("boo"); + + let source_code = new_text; + let source_bytes = source_code.as_bytes(); + let tree = parser + .parse(source_code, None) + .ok_or(v2::errors::Error::SnippetParsing) + .expect("boo"); + let root_node = tree.root_node(); + + // search for k nearest neighbors + let closest: Vec = appstate.v2.dict[lang] + .search(&target, 1) + .iter() + .filter_map(|&index| { + let applied = v2::mutation::apply( + langfn.clone(), + source_bytes, + root_node, + &appstate.v2.mutations_collection[index], + ); + match applied { + Ok(v) => Some(v), + Err(e) => { + error!( + collection_index = index, + "failed to apply mutations from collection {}", e + ); + None + } + } + // TODO: change the expect to a log + }) + .collect(); + + let closest = closest[0].clone(); + + let text_edit = TextEdit { + range, + new_text: closest, + }; + let changes: HashMap = [(uri, vec![text_edit])].into_iter().collect(); + let edit = WorkspaceEdit { + changes: Some(changes), + document_changes: None, + change_annotations: None, + }; + let actions = vec![CodeActionOrCommand::CodeAction(CodeAction { + title: "ask silos".to_string(), + kind: None, + diagnostics: None, + edit: Some(edit), + command: None, + is_preferred: None, + disabled: None, + data: None, + })]; + Ok(Some(actions)) + } } diff --git a/src/v2/api.rs b/src/v2/api.rs index a8346ec..cad0139 100644 --- a/src/v2/api.rs +++ b/src/v2/api.rs @@ -33,7 +33,7 @@ pub struct Snippet { body: String, } -fn get_lang(s: &str) -> Result { +pub fn get_lang(s: &str) -> Result { Ok(match s { "go" => tree_sitter_go::LANGUAGE, "c" => tree_sitter_c::LANGUAGE, From 4b710e867582c7f7ca62bc232c66ab31a8c72011 Mon Sep 17 00:00:00 2001 From: Himadri Bhattacharjee <107522312+lavafroth@users.noreply.github.com> Date: Mon, 30 Jun 2025 07:35:32 +0530 Subject: [PATCH 2/6] feat: add mode handler for http or lsp --- Cargo.lock | 2 +- README.md | 2 +- src/main.rs | 57 ++++++++++++++++++++++++++++++++++------------------- 3 files changed, 39 insertions(+), 22 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 96dd42c..e6fb8af 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2925,7 +2925,7 @@ dependencies = [ [[package]] name = "silos" -version = "1.0.0" +version = "1.1.0" dependencies = [ "actix-web", "anyhow", diff --git a/README.md b/README.md index 9dc496a..a969e56 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ cd silos ``` ``` sh -cargo r +cargo r http ``` > [!NOTE] diff --git a/src/main.rs b/src/main.rs index f5b0293..2b2ad67 100644 --- a/src/main.rs +++ b/src/main.rs @@ -23,6 +23,7 @@ mod v2; #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] struct Args { + /// The mode to run the server in. Defaults to LSP. The HTTP REST API can be run by specifying `http` or `http:port`. For example: `http:7047` mode: Option, /// Run on the Nth GPU device. @@ -36,10 +37,11 @@ struct Args { /// Revision or branch. #[arg(long)] revision: Option, +} - /// The port for the API to listen on - #[arg(long, default_value = "8000")] - port: u16, +pub enum RunMode { + Http(u16), + Lsp, } impl Args { @@ -54,6 +56,23 @@ impl Args { (None, None) => (default_model, default_revision), } } + fn mode(&self) -> RunMode { + let Some(http) = &self.mode else { + return RunMode::Lsp; + }; + if http == "http" { + return RunMode::Http(8000); + } + let Some(port) = http.strip_prefix("http:") else { + return RunMode::Lsp; + }; + + let Ok(port) = port.parse() else { + return RunMode::Lsp; + }; + + RunMode::Http(port) + } } fn path_to_parent_base(p: &std::path::Path) -> Result { @@ -72,8 +91,7 @@ fn path_to_parent_base(p: &std::path::Path) -> Result { async fn main() -> Result<()> { tracing_subscriber::fmt::init(); let args = Args::parse(); - let port = args.port; - let mode = args.mode.clone(); + let mode = args.mode(); let mut embed = embed::Embed::new(args)?; let mut dict = HashMap::default(); @@ -150,8 +168,8 @@ async fn main() -> Result<()> { let appstate_wrapped = web::Data::new(appstate.build()); - if mode.is_some_and(|v| v == "http") { - HttpServer::new(move || { + if let RunMode::Http(port) = mode { + return HttpServer::new(move || { App::new() .app_data(appstate_wrapped.clone()) .service(v1::api::get_snippet) @@ -161,19 +179,19 @@ async fn main() -> Result<()> { .bind(("127.0.0.1", port))? .run() .await - .map_err(E::from) - } else { - let stdin = tokio::io::stdin(); - let stdout = tokio::io::stdout(); + .map_err(E::from); + }; - let (service, socket) = LspService::new(|client| Backend { - client, - body: Arc::new(Mutex::new(String::default())), - appstate: appstate_wrapped.clone(), - }); - Server::new(stdin, stdout, socket).serve(service).await; - Ok(()) - } + let stdin = tokio::io::stdin(); + let stdout = tokio::io::stdout(); + + let (service, socket) = LspService::new(|client| Backend { + client, + body: Arc::new(Mutex::new(String::default())), + appstate: appstate_wrapped.clone(), + }); + Server::new(stdin, stdout, socket).serve(service).await; + Ok(()) } struct Backend { @@ -259,7 +277,6 @@ impl LanguageServer for Backend { let Some((desc, _after)) = after.split_once("\n") else { return Ok(None); }; - let Some((prompt, lang)) = desc.rsplit_once(" in ") else { error!("{}", v2::errors::Error::MissingSuffix); From 716b9ed3e23439cbb4b953130eb275d545d488e5 Mon Sep 17 00:00:00 2001 From: Himadri Bhattacharjee <107522312+lavafroth@users.noreply.github.com> Date: Tue, 1 Jul 2025 18:58:18 +0530 Subject: [PATCH 3/6] refactor: move common closest mutation search into a function --- src/main.rs | 101 +++++++++++++++----------------------------------- src/v2/api.rs | 25 +++++++++++-- 2 files changed, 50 insertions(+), 76 deletions(-) diff --git a/src/main.rs b/src/main.rs index 2b2ad67..75a9f6c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -12,8 +12,7 @@ use std::collections::HashMap; use tokio::sync::Mutex; use tower_lsp::lsp_types::*; use tower_lsp::{Client, LanguageServer, LspService, Server}; -use tracing::{error, info}; -use tree_sitter::Parser as TSParser; +use tracing::error; mod embed; mod state; @@ -267,6 +266,7 @@ impl LanguageServer for Backend { params: CodeActionParams, ) -> tower_lsp::jsonrpc::Result> { let uri = params.text_document.uri; + let extension = url_extension(&uri); let body = self.body.lock().await.to_string(); let range = params.range; @@ -278,93 +278,50 @@ impl LanguageServer for Backend { return Ok(None); }; - let Some((prompt, lang)) = desc.rsplit_once(" in ") else { + let (prompt, lang) = if let Some(ext) = extension { + (desc, ext) + } else if let Some((prompt, lang)) = desc.rsplit_once(" in ") { + (prompt, lang.to_string()) + } else { error!("{}", v2::errors::Error::MissingSuffix); return Ok(None); }; - let langfn = match v2::api::get_lang(lang) { - Ok(o) => o, - Err(e) => { - error!("{e}"); - return Ok(None); - } - }; - - info!(prompt = prompt, language = lang, "v2 request"); - - let mut appstate = self - .appstate - .inner - .lock() - .map_err(|_| v2::errors::Error::Busy) - .expect("booo"); - let target = appstate - .embed - .embed(prompt) - .map_err(|_| v2::errors::Error::EmbedFailed) - .expect("booo"); - let mut parser = TSParser::new(); - parser - .set_language(&langfn) - .map_err(|_| v2::errors::Error::UnknownLang) - .expect("boo"); - - let source_code = new_text; - let source_bytes = source_code.as_bytes(); - let tree = parser - .parse(source_code, None) - .ok_or(v2::errors::Error::SnippetParsing) - .expect("boo"); - let root_node = tree.root_node(); - - // search for k nearest neighbors - let closest: Vec = appstate.v2.dict[lang] - .search(&target, 1) - .iter() - .filter_map(|&index| { - let applied = v2::mutation::apply( - langfn.clone(), - source_bytes, - root_node, - &appstate.v2.mutations_collection[index], - ); - match applied { - Ok(v) => Some(v), - Err(e) => { - error!( - collection_index = index, - "failed to apply mutations from collection {}", e - ); - None - } + let closest_matches = + match v2::api::closest_mutation(&lang, prompt, &body, 1, &self.appstate) { + Ok(v) => v, + Err(e) => { + error!("{}", e); + return Ok(None); } - // TODO: change the expect to a log - }) - .collect(); - - let closest = closest[0].clone(); + }; + let Some(closest) = closest_matches.into_iter().next() else { + return Ok(None); + }; let text_edit = TextEdit { range, new_text: closest, }; let changes: HashMap = [(uri, vec![text_edit])].into_iter().collect(); - let edit = WorkspaceEdit { + let edit = Some(WorkspaceEdit { changes: Some(changes), document_changes: None, change_annotations: None, - }; + }); let actions = vec![CodeActionOrCommand::CodeAction(CodeAction { title: "ask silos".to_string(), - kind: None, - diagnostics: None, - edit: Some(edit), - command: None, - is_preferred: None, - disabled: None, - data: None, + edit, + ..Default::default() })]; Ok(Some(actions)) } } + +pub fn url_extension(u: &Url) -> Option { + let file_path = u.to_file_path().ok()?; + + let extension = file_path.extension()?; + let extension = extension.to_str()?; + Some(extension.to_string()) +} diff --git a/src/v2/api.rs b/src/v2/api.rs index cad0139..480a4af 100644 --- a/src/v2/api.rs +++ b/src/v2/api.rs @@ -52,6 +52,23 @@ pub(crate) async fn get_snippet( return Err(Error::MissingSuffix); }; + let closest = closest_mutation( + lang, + prompt, + snippet_request.body.as_str(), + snippet_request.top_k.unwrap_or(1), + &data, + )?; + Ok(web::Json(closest)) +} + +pub fn closest_mutation( + lang: &str, + prompt: &str, + body: &str, + top_k: usize, + data: &web::Data, +) -> Result, Error> { let langfn = get_lang(lang)?; info!(prompt = prompt, language = lang, "v2 request"); @@ -66,7 +83,7 @@ pub(crate) async fn get_snippet( .set_language(&langfn) .map_err(|_| Error::UnknownLang)?; - let source_code = snippet_request.body.as_str(); + let source_code = body; let source_bytes = source_code.as_bytes(); let tree = parser .parse(source_code, None) @@ -74,8 +91,8 @@ pub(crate) async fn get_snippet( let root_node = tree.root_node(); // search for k nearest neighbors - let closest: Vec = appstate.v2.dict[lang] - .search(&target, snippet_request.top_k.unwrap_or(1)) + let collected = appstate.v2.dict[lang] + .search(&target, top_k) .iter() .filter_map(|&index| { let applied = mutation::apply( @@ -97,5 +114,5 @@ pub(crate) async fn get_snippet( // TODO: change the expect to a log }) .collect(); - Ok(web::Json(closest)) + Ok(collected) } From efec8c2220e94c59d080065ad9c4142bb9bbc1dd Mon Sep 17 00:00:00 2001 From: Himadri Bhattacharjee <107522312+lavafroth@users.noreply.github.com> Date: Tue, 1 Jul 2025 19:07:07 +0530 Subject: [PATCH 4/6] refactor: decouple cli args and embed module --- src/args.rs | 56 ++++++++++++++++++++++++++++++++++++++++++++ src/embed.rs | 8 +++---- src/main.rs | 66 +++++----------------------------------------------- 3 files changed, 65 insertions(+), 65 deletions(-) create mode 100644 src/args.rs diff --git a/src/args.rs b/src/args.rs new file mode 100644 index 0000000..900c0d0 --- /dev/null +++ b/src/args.rs @@ -0,0 +1,56 @@ +use clap::Parser; + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +pub(crate) struct Args { + /// The mode to run the server in. Defaults to LSP. The HTTP REST API can be run by specifying `http` or `http:port`. For example: `http:7047` + pub(crate) mode: Option, + + /// Run on the Nth GPU device. + #[arg(long)] + pub(crate) gpu: Option, + + /// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending + #[arg(long)] + pub(crate) model_id: Option, + + /// Revision or branch. + #[arg(long)] + pub(crate) revision: Option, +} + +pub enum RunMode { + Http(u16), + Lsp, +} + +impl Args { + pub(crate) fn resolve_model_and_revision(&self) -> (String, String) { + let default_model = "sentence-transformers/all-MiniLM-L6-v2".to_string(); + let default_revision = "refs/pr/21".to_string(); + + match (self.model_id.clone(), self.revision.clone()) { + (Some(model_id), Some(revision)) => (model_id, revision), + (Some(model_id), None) => (model_id, "main".to_owned()), + (None, Some(revision)) => (default_model, revision), + (None, None) => (default_model, default_revision), + } + } + pub(crate) fn mode(&self) -> RunMode { + let Some(http) = &self.mode else { + return RunMode::Lsp; + }; + if http == "http" { + return RunMode::Http(8000); + } + let Some(port) = http.strip_prefix("http:") else { + return RunMode::Lsp; + }; + + let Ok(port) = port.parse() else { + return RunMode::Lsp; + }; + + RunMode::Http(port) + } +} diff --git a/src/embed.rs b/src/embed.rs index bdbbd61..1b9e104 100644 --- a/src/embed.rs +++ b/src/embed.rs @@ -1,4 +1,3 @@ -use super::Args; use anyhow::{Error as E, Result}; use candle_core::Device; use candle_core::Tensor; @@ -15,16 +14,15 @@ pub struct Embed { } impl Embed { - pub(crate) fn new(args: Args) -> Result { - let device = if let Some(gpu_dev) = args.gpu { + pub(crate) fn new(gpu: Option, model_id: &str, revision: &str) -> Result { + let device = if let Some(gpu_dev) = gpu { Device::new_cuda(gpu_dev)? } else { Device::Cpu }; - let (model_id, revision) = args.resolve_model_and_revision(); let (config_path, tokenizer_path, weights_path) = - Self::download_model_files(&model_id, &revision)?; + Self::download_model_files(model_id, revision)?; let config = std::fs::read_to_string(config_path)?; let config: Config = serde_json::from_str(&config)?; diff --git a/src/main.rs b/src/main.rs index 75a9f6c..a463074 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,3 @@ -use std::sync::Arc; - use actix_web::web::Data; use actix_web::{App, HttpServer, web}; use anyhow::{Context, Error as E, Result, bail}; @@ -9,71 +7,18 @@ use hora::index::hnsw_idx::HNSWIndex; use kdl::KdlDocument; use state::{State, StateWrapper}; use std::collections::HashMap; +use std::sync::Arc; use tokio::sync::Mutex; use tower_lsp::lsp_types::*; use tower_lsp::{Client, LanguageServer, LspService, Server}; use tracing::error; +mod args; mod embed; mod state; mod v1; mod v2; -#[derive(Parser, Debug)] -#[command(author, version, about, long_about = None)] -struct Args { - /// The mode to run the server in. Defaults to LSP. The HTTP REST API can be run by specifying `http` or `http:port`. For example: `http:7047` - mode: Option, - - /// Run on the Nth GPU device. - #[arg(long)] - gpu: Option, - - /// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending - #[arg(long)] - model_id: Option, - - /// Revision or branch. - #[arg(long)] - revision: Option, -} - -pub enum RunMode { - Http(u16), - Lsp, -} - -impl Args { - fn resolve_model_and_revision(&self) -> (String, String) { - let default_model = "sentence-transformers/all-MiniLM-L6-v2".to_string(); - let default_revision = "refs/pr/21".to_string(); - - match (self.model_id.clone(), self.revision.clone()) { - (Some(model_id), Some(revision)) => (model_id, revision), - (Some(model_id), None) => (model_id, "main".to_owned()), - (None, Some(revision)) => (default_model, revision), - (None, None) => (default_model, default_revision), - } - } - fn mode(&self) -> RunMode { - let Some(http) = &self.mode else { - return RunMode::Lsp; - }; - if http == "http" { - return RunMode::Http(8000); - } - let Some(port) = http.strip_prefix("http:") else { - return RunMode::Lsp; - }; - - let Ok(port) = port.parse() else { - return RunMode::Lsp; - }; - - RunMode::Http(port) - } -} - fn path_to_parent_base(p: &std::path::Path) -> Result { let Some(parent) = p .parent() @@ -89,9 +34,10 @@ fn path_to_parent_base(p: &std::path::Path) -> Result { #[actix_web::main] async fn main() -> Result<()> { tracing_subscriber::fmt::init(); - let args = Args::parse(); + let args = args::Args::parse(); let mode = args.mode(); - let mut embed = embed::Embed::new(args)?; + let (model_id, revision) = args.resolve_model_and_revision(); + let mut embed = embed::Embed::new(args.gpu, &model_id, &revision)?; let mut dict = HashMap::default(); let paths = glob::glob("./snippets/v1/*/*.kdl")?; @@ -167,7 +113,7 @@ async fn main() -> Result<()> { let appstate_wrapped = web::Data::new(appstate.build()); - if let RunMode::Http(port) = mode { + if let args::RunMode::Http(port) = mode { return HttpServer::new(move || { App::new() .app_data(appstate_wrapped.clone()) From 55e915cc32d9057b9f0c09352bb9ad6926fb6e3d Mon Sep 17 00:00:00 2001 From: Himadri Bhattacharjee <107522312+lavafroth@users.noreply.github.com> Date: Wed, 2 Jul 2025 10:15:52 +0530 Subject: [PATCH 5/6] feat: shard lsp module --- src/lsp.rs | 142 ++++++++++++++++++++++++++++++++++++++++++++++++++ src/main.rs | 147 ++-------------------------------------------------- 2 files changed, 147 insertions(+), 142 deletions(-) create mode 100644 src/lsp.rs diff --git a/src/lsp.rs b/src/lsp.rs new file mode 100644 index 0000000..2f631c0 --- /dev/null +++ b/src/lsp.rs @@ -0,0 +1,142 @@ +use crate::StateWrapper; +use crate::v2; +use actix_web::web::Data; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::Mutex; +use tower_lsp::lsp_types::*; +use tower_lsp::{Client, LanguageServer}; +use tracing::error; + +pub struct Backend { + pub client: Client, + pub body: Arc>, + pub appstate: Data, +} + +pub fn string_range_index(s: &str, r: Range) -> &str { + let mut newline_count = 0; + let mut start = None; + let mut end = None; + for (i, c) in s.chars().enumerate() { + if newline_count == r.start.line && start.is_none() { + start.replace(i + r.start.character as usize); + } + + if newline_count == r.end.line && end.is_none() { + end.replace(i + r.end.character as usize); + } + if c == '\n' { + newline_count += 1; + } + } + &s[start.unwrap_or_default()..end.unwrap_or(s.len())] +} + +#[tower_lsp::async_trait] +impl LanguageServer for Backend { + async fn initialize( + &self, + _: InitializeParams, + ) -> tower_lsp::jsonrpc::Result { + Ok(InitializeResult { + capabilities: ServerCapabilities { + text_document_sync: Some(TextDocumentSyncCapability::Kind( + TextDocumentSyncKind::FULL, + )), + code_action_provider: Some( + tower_lsp::lsp_types::CodeActionProviderCapability::Options( + CodeActionOptions::default(), + ), + ), + ..Default::default() + }, + ..Default::default() + }) + } + + async fn initialized(&self, _: InitializedParams) { + self.client + .log_message(MessageType::INFO, "server initialized!") + .await; + } + + async fn shutdown(&self) -> tower_lsp::jsonrpc::Result<()> { + Ok(()) + } + + async fn did_open(&self, params: DidOpenTextDocumentParams) { + // TODO: build an index for multiple documents in workdir + *self.body.lock().await = params.text_document.text; + } + + async fn did_change(&self, params: DidChangeTextDocumentParams) { + if let Some(body) = params.content_changes.into_iter().next() { + *self.body.lock().await = body.text; + } + } + + async fn code_action( + &self, + params: CodeActionParams, + ) -> tower_lsp::jsonrpc::Result> { + let uri = params.text_document.uri; + let extension = url_extension(&uri); + let body = self.body.lock().await.to_string(); + + let range = params.range; + let new_text = string_range_index(&body, range); + let Some((_before, after)) = new_text.split_once("silos: ") else { + return Ok(None); + }; + let Some((desc, _after)) = after.split_once("\n") else { + return Ok(None); + }; + + let (prompt, lang) = if let Some(ext) = extension { + (desc, ext) + } else if let Some((prompt, lang)) = desc.rsplit_once(" in ") { + (prompt, lang.to_string()) + } else { + error!("{}", v2::errors::Error::MissingSuffix); + return Ok(None); + }; + + let closest_matches = + match v2::api::closest_mutation(&lang, prompt, &body, 1, &self.appstate) { + Ok(v) => v, + Err(e) => { + error!("{}", e); + return Ok(None); + } + }; + + let Some(closest) = closest_matches.into_iter().next() else { + return Ok(None); + }; + let text_edit = TextEdit { + range, + new_text: closest, + }; + let changes: HashMap = [(uri, vec![text_edit])].into_iter().collect(); + let edit = Some(WorkspaceEdit { + changes: Some(changes), + document_changes: None, + change_annotations: None, + }); + let actions = vec![CodeActionOrCommand::CodeAction(CodeAction { + title: "ask silos".to_string(), + edit, + ..Default::default() + })]; + Ok(Some(actions)) + } +} + +fn url_extension(u: &Url) -> Option { + let file_path = u.to_file_path().ok()?; + + let extension = file_path.extension()?; + let extension = extension.to_str()?; + Some(extension.to_string()) +} diff --git a/src/main.rs b/src/main.rs index a463074..7514998 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,20 +1,18 @@ -use actix_web::web::Data; use actix_web::{App, HttpServer, web}; use anyhow::{Context, Error as E, Result, bail}; use clap::Parser; -use hora::core::ann_index::ANNIndex; +use hora::core::{ann_index::ANNIndex, metrics::Metric::Euclidean}; use hora::index::hnsw_idx::HNSWIndex; use kdl::KdlDocument; use state::{State, StateWrapper}; use std::collections::HashMap; use std::sync::Arc; use tokio::sync::Mutex; -use tower_lsp::lsp_types::*; -use tower_lsp::{Client, LanguageServer, LspService, Server}; -use tracing::error; +use tower_lsp::{LspService, Server}; mod args; mod embed; +mod lsp; mod state; mod v1; mod v2; @@ -97,9 +95,7 @@ async fn main() -> Result<()> { } for index in v2_dict.values_mut() { - index - .build(hora::core::metrics::Metric::Euclidean) - .map_err(E::msg)?; + index.build(Euclidean).map_err(E::msg)?; } let appstate = State { @@ -130,7 +126,7 @@ async fn main() -> Result<()> { let stdin = tokio::io::stdin(); let stdout = tokio::io::stdout(); - let (service, socket) = LspService::new(|client| Backend { + let (service, socket) = LspService::new(|client| lsp::Backend { client, body: Arc::new(Mutex::new(String::default())), appstate: appstate_wrapped.clone(), @@ -138,136 +134,3 @@ async fn main() -> Result<()> { Server::new(stdin, stdout, socket).serve(service).await; Ok(()) } - -struct Backend { - client: Client, - body: Arc>, - appstate: Data, -} - -pub fn string_range_index(s: &str, r: Range) -> &str { - let mut newline_count = 0; - let mut start = None; - let mut end = None; - for (i, c) in s.chars().enumerate() { - if newline_count == r.start.line && start.is_none() { - start.replace(i + r.start.character as usize); - } - - if newline_count == r.end.line && end.is_none() { - end.replace(i + r.end.character as usize); - } - if c == '\n' { - newline_count += 1; - } - } - &s[start.unwrap_or_default()..end.unwrap_or(s.len())] -} - -#[tower_lsp::async_trait] -impl LanguageServer for Backend { - async fn initialize( - &self, - _: InitializeParams, - ) -> tower_lsp::jsonrpc::Result { - Ok(InitializeResult { - capabilities: ServerCapabilities { - text_document_sync: Some(TextDocumentSyncCapability::Kind( - TextDocumentSyncKind::FULL, - )), - code_action_provider: Some( - tower_lsp::lsp_types::CodeActionProviderCapability::Options( - CodeActionOptions::default(), - ), - ), - ..Default::default() - }, - ..Default::default() - }) - } - - async fn initialized(&self, _: InitializedParams) { - self.client - .log_message(MessageType::INFO, "server initialized!") - .await; - } - - async fn shutdown(&self) -> tower_lsp::jsonrpc::Result<()> { - Ok(()) - } - - async fn did_open(&self, params: DidOpenTextDocumentParams) { - // TODO: build an index for multiple documents in workdir - *self.body.lock().await = params.text_document.text; - } - - async fn did_change(&self, params: DidChangeTextDocumentParams) { - if let Some(body) = params.content_changes.into_iter().next() { - *self.body.lock().await = body.text; - } - } - - async fn code_action( - &self, - params: CodeActionParams, - ) -> tower_lsp::jsonrpc::Result> { - let uri = params.text_document.uri; - let extension = url_extension(&uri); - let body = self.body.lock().await.to_string(); - - let range = params.range; - let new_text = string_range_index(&body, range); - let Some((_before, after)) = new_text.split_once("silos: ") else { - return Ok(None); - }; - let Some((desc, _after)) = after.split_once("\n") else { - return Ok(None); - }; - - let (prompt, lang) = if let Some(ext) = extension { - (desc, ext) - } else if let Some((prompt, lang)) = desc.rsplit_once(" in ") { - (prompt, lang.to_string()) - } else { - error!("{}", v2::errors::Error::MissingSuffix); - return Ok(None); - }; - - let closest_matches = - match v2::api::closest_mutation(&lang, prompt, &body, 1, &self.appstate) { - Ok(v) => v, - Err(e) => { - error!("{}", e); - return Ok(None); - } - }; - - let Some(closest) = closest_matches.into_iter().next() else { - return Ok(None); - }; - let text_edit = TextEdit { - range, - new_text: closest, - }; - let changes: HashMap = [(uri, vec![text_edit])].into_iter().collect(); - let edit = Some(WorkspaceEdit { - changes: Some(changes), - document_changes: None, - change_annotations: None, - }); - let actions = vec![CodeActionOrCommand::CodeAction(CodeAction { - title: "ask silos".to_string(), - edit, - ..Default::default() - })]; - Ok(Some(actions)) - } -} - -pub fn url_extension(u: &Url) -> Option { - let file_path = u.to_file_path().ok()?; - - let extension = file_path.extension()?; - let extension = extension.to_str()?; - Some(extension.to_string()) -} From 864c394ed74705743d78255c274427aaa6cc23b0 Mon Sep 17 00:00:00 2001 From: Himadri Bhattacharjee <107522312+lavafroth@users.noreply.github.com> Date: Wed, 2 Jul 2025 10:19:38 +0530 Subject: [PATCH 6/6] ver: 2.0.0 lsp mode introduces breaking changes --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 8cd3a24..622f8d6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "silos" -version = "1.1.0" +version = "2.0.0" edition = "2024" [dependencies]