ver: 2.0.0 merge branch 'lsp'
This commit is contained in:
6
.helix/languages.toml
Normal file
6
.helix/languages.toml
Normal file
@@ -0,0 +1,6 @@
|
||||
[language-server.silos]
|
||||
command = "./target/debug/silos"
|
||||
|
||||
[[language]]
|
||||
name = "go"
|
||||
language-servers = [ { name = "silos" } ]
|
||||
144
Cargo.lock
generated
144
Cargo.lock
generated
@@ -289,12 +289,34 @@ dependencies = [
|
||||
"derive_arbitrary",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-trait"
|
||||
version = "0.1.88"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e539d3fca749fcee5236ab05e93a52867dd549cc157c8cb7f99595f3cedffdb5"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "atomic-waker"
|
||||
version = "1.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0"
|
||||
|
||||
[[package]]
|
||||
name = "auto_impl"
|
||||
version = "1.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ffdcb70bdbc4d478427380519163274ac86e52916e10f0a8889adf0f96d3fee7"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "autocfg"
|
||||
version = "1.4.0"
|
||||
@@ -694,6 +716,19 @@ dependencies = [
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dashmap"
|
||||
version = "5.5.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"hashbrown 0.14.5",
|
||||
"lock_api",
|
||||
"once_cell",
|
||||
"parking_lot_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "deranged"
|
||||
version = "0.4.0"
|
||||
@@ -1377,6 +1412,12 @@ dependencies = [
|
||||
"rand_distr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hashbrown"
|
||||
version = "0.14.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
|
||||
|
||||
[[package]]
|
||||
name = "hashbrown"
|
||||
version = "0.15.3"
|
||||
@@ -1695,7 +1736,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cea70ddb795996207ad57735b50c5982d8844f38ba9ee5f1aedcfb708a2aa11e"
|
||||
dependencies = [
|
||||
"equivalent",
|
||||
"hashbrown",
|
||||
"hashbrown 0.15.3",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1878,6 +1919,19 @@ version = "0.4.27"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94"
|
||||
|
||||
[[package]]
|
||||
name = "lsp-types"
|
||||
version = "0.94.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c66bfd44a06ae10647fe3f8214762e9369fd4248df1350924b4ef9e770a85ea1"
|
||||
dependencies = [
|
||||
"bitflags 1.3.2",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_repr",
|
||||
"url",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "macro_rules_attribute"
|
||||
version = "0.2.2"
|
||||
@@ -2275,6 +2329,26 @@ version = "2.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e"
|
||||
|
||||
[[package]]
|
||||
name = "pin-project"
|
||||
version = "1.1.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "677f1add503faace112b9f1373e43e9e054bfdd22ff1a63c1bc485eaec6a6a8a"
|
||||
dependencies = [
|
||||
"pin-project-internal",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pin-project-internal"
|
||||
version = "1.1.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6e918e4ff8c4549eb882f14b3a4bc8c8bc93de829416eacf579f1207a8fbf861"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pin-project-lite"
|
||||
version = "0.2.16"
|
||||
@@ -2596,7 +2670,7 @@ dependencies = [
|
||||
"tokio",
|
||||
"tokio-native-tls",
|
||||
"tokio-util",
|
||||
"tower",
|
||||
"tower 0.5.2",
|
||||
"tower-http",
|
||||
"tower-service",
|
||||
"url",
|
||||
@@ -2791,6 +2865,17 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_repr"
|
||||
version = "0.1.20"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "175ee3e80ae9982737ca543e96133087cbd9a485eecc3bc4de9c1a37b47ea59c"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_urlencoded"
|
||||
version = "0.7.1"
|
||||
@@ -2840,7 +2925,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "silos"
|
||||
version = "1.0.0"
|
||||
version = "1.1.0"
|
||||
dependencies = [
|
||||
"actix-web",
|
||||
"anyhow",
|
||||
@@ -2856,6 +2941,8 @@ dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tokenizers",
|
||||
"tokio",
|
||||
"tower-lsp",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
"tree-sitter",
|
||||
@@ -3230,6 +3317,20 @@ dependencies = [
|
||||
"winnow 0.7.10",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tower"
|
||||
version = "0.4.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"futures-util",
|
||||
"pin-project",
|
||||
"pin-project-lite",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tower"
|
||||
version = "0.5.2"
|
||||
@@ -3258,7 +3359,7 @@ dependencies = [
|
||||
"http-body",
|
||||
"iri-string",
|
||||
"pin-project-lite",
|
||||
"tower",
|
||||
"tower 0.5.2",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
]
|
||||
@@ -3269,6 +3370,40 @@ version = "0.3.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e"
|
||||
|
||||
[[package]]
|
||||
name = "tower-lsp"
|
||||
version = "0.20.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d4ba052b54a6627628d9b3c34c176e7eda8359b7da9acd497b9f20998d118508"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"auto_impl",
|
||||
"bytes",
|
||||
"dashmap",
|
||||
"futures",
|
||||
"httparse",
|
||||
"lsp-types",
|
||||
"memchr",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
"tower 0.4.13",
|
||||
"tower-lsp-macros",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tower-lsp-macros"
|
||||
version = "0.9.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "84fd902d4e0b9a4b27f2f440108dc034e1758628a9b702f8ec61ad66355422fa"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tower-service"
|
||||
version = "0.3.3"
|
||||
@@ -3496,6 +3631,7 @@ dependencies = [
|
||||
"form_urlencoded",
|
||||
"idna",
|
||||
"percent-encoding",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "silos"
|
||||
version = "1.0.0"
|
||||
version = "2.0.0"
|
||||
edition = "2024"
|
||||
|
||||
[dependencies]
|
||||
@@ -24,3 +24,5 @@ tree-sitter = "0.25.6"
|
||||
tree-sitter-c = "0.24.1"
|
||||
tree-sitter-go = "0.23.4"
|
||||
tree-sitter-rust = "0.24.0"
|
||||
tokio = { version = "1.45.1", features = ["io-std", "macros", "rt", "rt-multi-thread"] }
|
||||
tower-lsp = "0.20.0"
|
||||
|
||||
56
src/args.rs
Normal file
56
src/args.rs
Normal file
@@ -0,0 +1,56 @@
|
||||
use clap::Parser;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
pub(crate) struct Args {
|
||||
/// The mode to run the server in. Defaults to LSP. The HTTP REST API can be run by specifying `http` or `http:port`. For example: `http:7047`
|
||||
pub(crate) mode: Option<String>,
|
||||
|
||||
/// Run on the Nth GPU device.
|
||||
#[arg(long)]
|
||||
pub(crate) gpu: Option<usize>,
|
||||
|
||||
/// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
|
||||
#[arg(long)]
|
||||
pub(crate) model_id: Option<String>,
|
||||
|
||||
/// Revision or branch.
|
||||
#[arg(long)]
|
||||
pub(crate) revision: Option<String>,
|
||||
}
|
||||
|
||||
pub enum RunMode {
|
||||
Http(u16),
|
||||
Lsp,
|
||||
}
|
||||
|
||||
impl Args {
|
||||
pub(crate) fn resolve_model_and_revision(&self) -> (String, String) {
|
||||
let default_model = "sentence-transformers/all-MiniLM-L6-v2".to_string();
|
||||
let default_revision = "refs/pr/21".to_string();
|
||||
|
||||
match (self.model_id.clone(), self.revision.clone()) {
|
||||
(Some(model_id), Some(revision)) => (model_id, revision),
|
||||
(Some(model_id), None) => (model_id, "main".to_owned()),
|
||||
(None, Some(revision)) => (default_model, revision),
|
||||
(None, None) => (default_model, default_revision),
|
||||
}
|
||||
}
|
||||
pub(crate) fn mode(&self) -> RunMode {
|
||||
let Some(http) = &self.mode else {
|
||||
return RunMode::Lsp;
|
||||
};
|
||||
if http == "http" {
|
||||
return RunMode::Http(8000);
|
||||
}
|
||||
let Some(port) = http.strip_prefix("http:") else {
|
||||
return RunMode::Lsp;
|
||||
};
|
||||
|
||||
let Ok(port) = port.parse() else {
|
||||
return RunMode::Lsp;
|
||||
};
|
||||
|
||||
RunMode::Http(port)
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,3 @@
|
||||
use super::Args;
|
||||
use anyhow::{Error as E, Result};
|
||||
use candle_core::Device;
|
||||
use candle_core::Tensor;
|
||||
@@ -15,16 +14,15 @@ pub struct Embed {
|
||||
}
|
||||
|
||||
impl Embed {
|
||||
pub(crate) fn new(args: Args) -> Result<Self> {
|
||||
let device = if let Some(gpu_dev) = args.gpu {
|
||||
pub(crate) fn new(gpu: Option<usize>, model_id: &str, revision: &str) -> Result<Self> {
|
||||
let device = if let Some(gpu_dev) = gpu {
|
||||
Device::new_cuda(gpu_dev)?
|
||||
} else {
|
||||
Device::Cpu
|
||||
};
|
||||
|
||||
let (model_id, revision) = args.resolve_model_and_revision();
|
||||
let (config_path, tokenizer_path, weights_path) =
|
||||
Self::download_model_files(&model_id, &revision)?;
|
||||
Self::download_model_files(model_id, revision)?;
|
||||
|
||||
let config = std::fs::read_to_string(config_path)?;
|
||||
let config: Config = serde_json::from_str(&config)?;
|
||||
|
||||
142
src/lsp.rs
Normal file
142
src/lsp.rs
Normal file
@@ -0,0 +1,142 @@
|
||||
use crate::StateWrapper;
|
||||
use crate::v2;
|
||||
use actix_web::web::Data;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
use tower_lsp::lsp_types::*;
|
||||
use tower_lsp::{Client, LanguageServer};
|
||||
use tracing::error;
|
||||
|
||||
pub struct Backend {
|
||||
pub client: Client,
|
||||
pub body: Arc<Mutex<String>>,
|
||||
pub appstate: Data<StateWrapper>,
|
||||
}
|
||||
|
||||
pub fn string_range_index(s: &str, r: Range) -> &str {
|
||||
let mut newline_count = 0;
|
||||
let mut start = None;
|
||||
let mut end = None;
|
||||
for (i, c) in s.chars().enumerate() {
|
||||
if newline_count == r.start.line && start.is_none() {
|
||||
start.replace(i + r.start.character as usize);
|
||||
}
|
||||
|
||||
if newline_count == r.end.line && end.is_none() {
|
||||
end.replace(i + r.end.character as usize);
|
||||
}
|
||||
if c == '\n' {
|
||||
newline_count += 1;
|
||||
}
|
||||
}
|
||||
&s[start.unwrap_or_default()..end.unwrap_or(s.len())]
|
||||
}
|
||||
|
||||
#[tower_lsp::async_trait]
|
||||
impl LanguageServer for Backend {
|
||||
async fn initialize(
|
||||
&self,
|
||||
_: InitializeParams,
|
||||
) -> tower_lsp::jsonrpc::Result<InitializeResult> {
|
||||
Ok(InitializeResult {
|
||||
capabilities: ServerCapabilities {
|
||||
text_document_sync: Some(TextDocumentSyncCapability::Kind(
|
||||
TextDocumentSyncKind::FULL,
|
||||
)),
|
||||
code_action_provider: Some(
|
||||
tower_lsp::lsp_types::CodeActionProviderCapability::Options(
|
||||
CodeActionOptions::default(),
|
||||
),
|
||||
),
|
||||
..Default::default()
|
||||
},
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
|
||||
async fn initialized(&self, _: InitializedParams) {
|
||||
self.client
|
||||
.log_message(MessageType::INFO, "server initialized!")
|
||||
.await;
|
||||
}
|
||||
|
||||
async fn shutdown(&self) -> tower_lsp::jsonrpc::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn did_open(&self, params: DidOpenTextDocumentParams) {
|
||||
// TODO: build an index for multiple documents in workdir
|
||||
*self.body.lock().await = params.text_document.text;
|
||||
}
|
||||
|
||||
async fn did_change(&self, params: DidChangeTextDocumentParams) {
|
||||
if let Some(body) = params.content_changes.into_iter().next() {
|
||||
*self.body.lock().await = body.text;
|
||||
}
|
||||
}
|
||||
|
||||
async fn code_action(
|
||||
&self,
|
||||
params: CodeActionParams,
|
||||
) -> tower_lsp::jsonrpc::Result<Option<CodeActionResponse>> {
|
||||
let uri = params.text_document.uri;
|
||||
let extension = url_extension(&uri);
|
||||
let body = self.body.lock().await.to_string();
|
||||
|
||||
let range = params.range;
|
||||
let new_text = string_range_index(&body, range);
|
||||
let Some((_before, after)) = new_text.split_once("silos: ") else {
|
||||
return Ok(None);
|
||||
};
|
||||
let Some((desc, _after)) = after.split_once("\n") else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let (prompt, lang) = if let Some(ext) = extension {
|
||||
(desc, ext)
|
||||
} else if let Some((prompt, lang)) = desc.rsplit_once(" in ") {
|
||||
(prompt, lang.to_string())
|
||||
} else {
|
||||
error!("{}", v2::errors::Error::MissingSuffix);
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let closest_matches =
|
||||
match v2::api::closest_mutation(&lang, prompt, &body, 1, &self.appstate) {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
error!("{}", e);
|
||||
return Ok(None);
|
||||
}
|
||||
};
|
||||
|
||||
let Some(closest) = closest_matches.into_iter().next() else {
|
||||
return Ok(None);
|
||||
};
|
||||
let text_edit = TextEdit {
|
||||
range,
|
||||
new_text: closest,
|
||||
};
|
||||
let changes: HashMap<Url, _> = [(uri, vec![text_edit])].into_iter().collect();
|
||||
let edit = Some(WorkspaceEdit {
|
||||
changes: Some(changes),
|
||||
document_changes: None,
|
||||
change_annotations: None,
|
||||
});
|
||||
let actions = vec![CodeActionOrCommand::CodeAction(CodeAction {
|
||||
title: "ask silos".to_string(),
|
||||
edit,
|
||||
..Default::default()
|
||||
})];
|
||||
Ok(Some(actions))
|
||||
}
|
||||
}
|
||||
|
||||
fn url_extension(u: &Url) -> Option<String> {
|
||||
let file_path = u.to_file_path().ok()?;
|
||||
|
||||
let extension = file_path.extension()?;
|
||||
let extension = extension.to_str()?;
|
||||
Some(extension.to_string())
|
||||
}
|
||||
89
src/main.rs
89
src/main.rs
@@ -1,51 +1,22 @@
|
||||
use actix_web::{App, HttpServer, web};
|
||||
use anyhow::{Context, Error as E, Result, bail};
|
||||
use clap::Parser;
|
||||
use hora::core::ann_index::ANNIndex;
|
||||
use hora::core::{ann_index::ANNIndex, metrics::Metric::Euclidean};
|
||||
use hora::index::hnsw_idx::HNSWIndex;
|
||||
use kdl::KdlDocument;
|
||||
use state::State;
|
||||
use state::{State, StateWrapper};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
use tower_lsp::{LspService, Server};
|
||||
|
||||
mod args;
|
||||
mod embed;
|
||||
mod lsp;
|
||||
mod state;
|
||||
mod v1;
|
||||
mod v2;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// Run on the Nth GPU device.
|
||||
#[arg(long)]
|
||||
gpu: Option<usize>,
|
||||
|
||||
/// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
/// Revision or branch.
|
||||
#[arg(long)]
|
||||
revision: Option<String>,
|
||||
|
||||
/// The port for the API to listen on
|
||||
#[arg(long, default_value = "8000")]
|
||||
port: u16,
|
||||
}
|
||||
|
||||
impl Args {
|
||||
fn resolve_model_and_revision(&self) -> (String, String) {
|
||||
let default_model = "sentence-transformers/all-MiniLM-L6-v2".to_string();
|
||||
let default_revision = "refs/pr/21".to_string();
|
||||
|
||||
match (self.model_id.clone(), self.revision.clone()) {
|
||||
(Some(model_id), Some(revision)) => (model_id, revision),
|
||||
(Some(model_id), None) => (model_id, "main".to_owned()),
|
||||
(None, Some(revision)) => (default_model, revision),
|
||||
(None, None) => (default_model, default_revision),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn path_to_parent_base(p: &std::path::Path) -> Result<String> {
|
||||
let Some(parent) = p
|
||||
.parent()
|
||||
@@ -61,9 +32,10 @@ fn path_to_parent_base(p: &std::path::Path) -> Result<String> {
|
||||
#[actix_web::main]
|
||||
async fn main() -> Result<()> {
|
||||
tracing_subscriber::fmt::init();
|
||||
let args = Args::parse();
|
||||
let port = args.port;
|
||||
let mut embed = embed::Embed::new(args)?;
|
||||
let args = args::Args::parse();
|
||||
let mode = args.mode();
|
||||
let (model_id, revision) = args.resolve_model_and_revision();
|
||||
let mut embed = embed::Embed::new(args.gpu, &model_id, &revision)?;
|
||||
let mut dict = HashMap::default();
|
||||
|
||||
let paths = glob::glob("./snippets/v1/*/*.kdl")?;
|
||||
@@ -123,9 +95,7 @@ async fn main() -> Result<()> {
|
||||
}
|
||||
|
||||
for index in v2_dict.values_mut() {
|
||||
index
|
||||
.build(hora::core::metrics::Metric::Euclidean)
|
||||
.map_err(E::msg)?;
|
||||
index.build(Euclidean).map_err(E::msg)?;
|
||||
}
|
||||
|
||||
let appstate = State {
|
||||
@@ -139,15 +109,28 @@ async fn main() -> Result<()> {
|
||||
|
||||
let appstate_wrapped = web::Data::new(appstate.build());
|
||||
|
||||
HttpServer::new(move || {
|
||||
App::new()
|
||||
.app_data(appstate_wrapped.clone())
|
||||
.service(v1::api::get_snippet)
|
||||
.service(v1::api::add_snippet)
|
||||
.service(v2::api::get_snippet)
|
||||
})
|
||||
.bind(("127.0.0.1", port))?
|
||||
.run()
|
||||
.await
|
||||
.map_err(E::from)
|
||||
if let args::RunMode::Http(port) = mode {
|
||||
return HttpServer::new(move || {
|
||||
App::new()
|
||||
.app_data(appstate_wrapped.clone())
|
||||
.service(v1::api::get_snippet)
|
||||
.service(v1::api::add_snippet)
|
||||
.service(v2::api::get_snippet)
|
||||
})
|
||||
.bind(("127.0.0.1", port))?
|
||||
.run()
|
||||
.await
|
||||
.map_err(E::from);
|
||||
};
|
||||
|
||||
let stdin = tokio::io::stdin();
|
||||
let stdout = tokio::io::stdout();
|
||||
|
||||
let (service, socket) = LspService::new(|client| lsp::Backend {
|
||||
client,
|
||||
body: Arc::new(Mutex::new(String::default())),
|
||||
appstate: appstate_wrapped.clone(),
|
||||
});
|
||||
Server::new(stdin, stdout, socket).serve(service).await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -33,7 +33,7 @@ pub struct Snippet {
|
||||
body: String,
|
||||
}
|
||||
|
||||
fn get_lang(s: &str) -> Result<tree_sitter::Language, Error> {
|
||||
pub fn get_lang(s: &str) -> Result<tree_sitter::Language, Error> {
|
||||
Ok(match s {
|
||||
"go" => tree_sitter_go::LANGUAGE,
|
||||
"c" => tree_sitter_c::LANGUAGE,
|
||||
@@ -52,6 +52,23 @@ pub(crate) async fn get_snippet(
|
||||
return Err(Error::MissingSuffix);
|
||||
};
|
||||
|
||||
let closest = closest_mutation(
|
||||
lang,
|
||||
prompt,
|
||||
snippet_request.body.as_str(),
|
||||
snippet_request.top_k.unwrap_or(1),
|
||||
&data,
|
||||
)?;
|
||||
Ok(web::Json(closest))
|
||||
}
|
||||
|
||||
pub fn closest_mutation(
|
||||
lang: &str,
|
||||
prompt: &str,
|
||||
body: &str,
|
||||
top_k: usize,
|
||||
data: &web::Data<crate::state::StateWrapper>,
|
||||
) -> Result<Vec<String>, Error> {
|
||||
let langfn = get_lang(lang)?;
|
||||
|
||||
info!(prompt = prompt, language = lang, "v2 request");
|
||||
@@ -66,7 +83,7 @@ pub(crate) async fn get_snippet(
|
||||
.set_language(&langfn)
|
||||
.map_err(|_| Error::UnknownLang)?;
|
||||
|
||||
let source_code = snippet_request.body.as_str();
|
||||
let source_code = body;
|
||||
let source_bytes = source_code.as_bytes();
|
||||
let tree = parser
|
||||
.parse(source_code, None)
|
||||
@@ -74,8 +91,8 @@ pub(crate) async fn get_snippet(
|
||||
let root_node = tree.root_node();
|
||||
|
||||
// search for k nearest neighbors
|
||||
let closest: Vec<String> = appstate.v2.dict[lang]
|
||||
.search(&target, snippet_request.top_k.unwrap_or(1))
|
||||
let collected = appstate.v2.dict[lang]
|
||||
.search(&target, top_k)
|
||||
.iter()
|
||||
.filter_map(|&index| {
|
||||
let applied = mutation::apply(
|
||||
@@ -97,5 +114,5 @@ pub(crate) async fn get_snippet(
|
||||
// TODO: change the expect to a log
|
||||
})
|
||||
.collect();
|
||||
Ok(web::Json(closest))
|
||||
Ok(collected)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user