Compare commits
19 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e24e62873f | ||
|
|
0b9ab89f35 | ||
|
|
650329206d | ||
|
|
633c1a206b | ||
|
|
ab4c62fcf4 | ||
|
|
6ff9ba9d16 | ||
|
|
d359121afd | ||
|
|
4abd2cffac | ||
|
|
daccd63006 | ||
|
|
87e096f0bc | ||
|
|
91d2640c11 | ||
|
|
ec3b89f455 | ||
|
|
c734c81a04 | ||
|
|
e7cae348a1 | ||
|
|
faea784d8f | ||
|
|
e8970f21ff | ||
|
|
a1445b2f03 | ||
|
|
8f5e618841 | ||
|
|
996142c8dd |
8
.github/dependabot.yml
vendored
Normal file
8
.github/dependabot.yml
vendored
Normal file
@@ -0,0 +1,8 @@
|
||||
version: 2
|
||||
updates:
|
||||
- package-ecosystem: "cargo" # See documentation for possible values
|
||||
directory: "/" # Location of package manifests
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
open-pull-requests-limit: 4
|
||||
|
||||
26
.github/workflows/release.yml
vendored
Normal file
26
.github/workflows/release.yml
vendored
Normal file
@@ -0,0 +1,26 @@
|
||||
on:
|
||||
release:
|
||||
types: [created]
|
||||
|
||||
jobs:
|
||||
release:
|
||||
name: release ${{ matrix.target }}
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- target: x86_64-pc-windows-gnu
|
||||
archive: zip
|
||||
- target: x86_64-unknown-linux-musl
|
||||
archive: tar.zst
|
||||
steps:
|
||||
- uses: actions/checkout@master
|
||||
- name: Compile and release
|
||||
uses: rust-build/rust-build.action@v1.4.5
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
with:
|
||||
RUSTTARGET: ${{ matrix.target }}
|
||||
ARCHIVE_TYPES: ${{ matrix.archive }}
|
||||
TOOLCHAIN_VERSION: stable
|
||||
22
.github/workflows/rust.yml
vendored
Normal file
22
.github/workflows/rust.yml
vendored
Normal file
@@ -0,0 +1,22 @@
|
||||
name: Build and test
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ "master" ]
|
||||
pull_request:
|
||||
branches: [ "master" ]
|
||||
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
|
||||
jobs:
|
||||
build:
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Build
|
||||
run: cargo build --verbose
|
||||
- name: Run tests
|
||||
run: cargo test --verbose
|
||||
48
Cargo.lock
generated
48
Cargo.lock
generated
@@ -299,9 +299,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
|
||||
|
||||
[[package]]
|
||||
name = "clap"
|
||||
version = "4.5.39"
|
||||
version = "4.5.41"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fd60e63e9be68e5fb56422e397cf9baddded06dae1d2e523401542383bc72a9f"
|
||||
checksum = "be92d32e80243a54711e5d7ce823c35c41c9d929dc4ab58e1276f625841aadf9"
|
||||
dependencies = [
|
||||
"clap_builder",
|
||||
"clap_derive",
|
||||
@@ -309,9 +309,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "clap_builder"
|
||||
version = "4.5.39"
|
||||
version = "4.5.41"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "89cc6392a1f72bbeb820d71f32108f61fdaf18bc526e1d23954168a67759ef51"
|
||||
checksum = "707eab41e9622f9139419d573eca0900137718000c517d47da73045f54331c3d"
|
||||
dependencies = [
|
||||
"anstream",
|
||||
"anstyle",
|
||||
@@ -321,9 +321,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "clap_derive"
|
||||
version = "4.5.32"
|
||||
version = "4.5.41"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "09176aae279615badda0765c0c0b3f6ed53f4709118af73cf4655d85d1530cd7"
|
||||
checksum = "ef4f52386a59ca4c860f7393bcf8abd8dfd91ecccc0f774635ff68e92eeef491"
|
||||
dependencies = [
|
||||
"heck",
|
||||
"proc-macro2",
|
||||
@@ -1050,12 +1050,6 @@ version = "0.31.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f"
|
||||
|
||||
[[package]]
|
||||
name = "glob"
|
||||
version = "0.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2"
|
||||
|
||||
[[package]]
|
||||
name = "h2"
|
||||
version = "0.4.10"
|
||||
@@ -2497,7 +2491,7 @@ checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"
|
||||
|
||||
[[package]]
|
||||
name = "silos"
|
||||
version = "3.0.0"
|
||||
version = "5.2.2"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"candle-core",
|
||||
@@ -2505,11 +2499,9 @@ dependencies = [
|
||||
"candle-transformers",
|
||||
"clap",
|
||||
"derive_more",
|
||||
"glob",
|
||||
"hf-hub",
|
||||
"hora",
|
||||
"kdl",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tokenizers",
|
||||
"tokio",
|
||||
@@ -2518,7 +2510,9 @@ dependencies = [
|
||||
"tracing-subscriber",
|
||||
"tree-sitter",
|
||||
"tree-sitter-c",
|
||||
"tree-sitter-cpp",
|
||||
"tree-sitter-go",
|
||||
"tree-sitter-javascript",
|
||||
"tree-sitter-rust",
|
||||
]
|
||||
|
||||
@@ -3007,9 +3001,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "tree-sitter"
|
||||
version = "0.25.6"
|
||||
version = "0.25.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a7cf18d43cbf0bfca51f657132cc616a5097edc4424d538bae6fa60142eaf9f0"
|
||||
checksum = "6d7b8994f367f16e6fa14b5aebbcb350de5d7cbea82dc5b00ae997dd71680dd2"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"regex",
|
||||
@@ -3029,6 +3023,16 @@ dependencies = [
|
||||
"tree-sitter-language",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tree-sitter-cpp"
|
||||
version = "0.23.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "df2196ea9d47b4ab4a31b9297eaa5a5d19a0b121dceb9f118f6790ad0ab94743"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"tree-sitter-language",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tree-sitter-go"
|
||||
version = "0.23.4"
|
||||
@@ -3039,6 +3043,16 @@ dependencies = [
|
||||
"tree-sitter-language",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tree-sitter-javascript"
|
||||
version = "0.23.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bf40bf599e0416c16c125c3cec10ee5ddc7d1bb8b0c60fa5c4de249ad34dc1b1"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"tree-sitter-language",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tree-sitter-language"
|
||||
version = "0.1.5"
|
||||
|
||||
10
Cargo.toml
10
Cargo.toml
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "silos"
|
||||
version = "4.0.0"
|
||||
version = "5.2.2"
|
||||
edition = "2024"
|
||||
|
||||
[dependencies]
|
||||
@@ -8,20 +8,20 @@ anyhow = "1.0.98"
|
||||
candle-core = "0.9.1"
|
||||
candle-nn = "0.9.1"
|
||||
candle-transformers = "0.9.1"
|
||||
clap = { version = "4.5.39", features = ["derive"] }
|
||||
clap = { version = "4.5.41", features = ["derive"] }
|
||||
derive_more = { version = "2.0.1", features = ["display", "error"] }
|
||||
glob = "0.3.2"
|
||||
hf-hub = "0.4.2"
|
||||
hora = "0.1.1"
|
||||
kdl = "6.3.4"
|
||||
serde = "1.0.219"
|
||||
serde_json = "1.0.140"
|
||||
tokenizers = "0.21.1"
|
||||
tracing = "0.1.41"
|
||||
tracing-subscriber = "0.3.19"
|
||||
tree-sitter = "0.25.6"
|
||||
tree-sitter = "0.25.8"
|
||||
tree-sitter-c = "0.24.1"
|
||||
tree-sitter-go = "0.23.4"
|
||||
tree-sitter-rust = "0.24.0"
|
||||
tokio = { version = "1.45.1", features = ["io-std", "macros", "rt", "rt-multi-thread"] }
|
||||
tower-lsp = "0.20.0"
|
||||
tree-sitter-javascript = "0.23.1"
|
||||
tree-sitter-cpp = "0.23.4"
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
|
||||
Dumb, proomptable modular snippet search.
|
||||
|
||||

|
||||
|
||||
## Installation
|
||||
|
||||
### Binary releases
|
||||
@@ -76,6 +78,8 @@ This API parses code into an AST (Abstract Syntax Tree) via tree-sitter and can
|
||||
- C
|
||||
- Rust
|
||||
- Go
|
||||
- Javascript
|
||||
- C++
|
||||
|
||||
### Defining mutation collections
|
||||
|
||||
|
||||
BIN
assets/preview.gif
Normal file
BIN
assets/preview.gif
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1.9 MiB |
6
flake.lock
generated
6
flake.lock
generated
@@ -2,11 +2,11 @@
|
||||
"nodes": {
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1749776303,
|
||||
"narHash": "sha256-OHibOvVwKqO1qvRg0r3agtd1EagW4THBcoWT7QGgcNo=",
|
||||
"lastModified": 1755020227,
|
||||
"narHash": "sha256-gGmm+h0t6rY88RPTaIm3su95QvQIVjAJx558YUG4Id8=",
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "6e7721e37bf00fa7ea44ac3cfc9d2411284ec3ef",
|
||||
"rev": "695d5db1b8b20b73292501683a524e0bd79074fb",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
|
||||
@@ -14,6 +14,10 @@ pub(crate) struct Args {
|
||||
/// Revision or branch.
|
||||
#[arg(long)]
|
||||
pub(crate) revision: Option<String>,
|
||||
|
||||
/// Path to the directory containing `generate` and `refactor` snippets.
|
||||
#[arg(long, default_value = "./snippets")]
|
||||
pub(crate) snippets: std::path::PathBuf,
|
||||
}
|
||||
|
||||
impl Args {
|
||||
|
||||
35
src/embed.rs
35
src/embed.rs
@@ -7,17 +7,24 @@ use hf_hub::Repo;
|
||||
use hf_hub::RepoType;
|
||||
use hf_hub::api::sync::Api;
|
||||
use std::path::PathBuf;
|
||||
use tokenizers::Tokenizer;
|
||||
use tokenizers::TokenizerImpl;
|
||||
use tokenizers::DecoderWrapper;
|
||||
use tokenizers::ModelWrapper;
|
||||
use tokenizers::NormalizerWrapper;
|
||||
use tokenizers::PreTokenizerWrapper;
|
||||
use tokenizers::PostProcessorWrapper;
|
||||
use tokenizers::DecoderWrapper;
|
||||
use tokenizers::PreTokenizerWrapper;
|
||||
use tokenizers::Tokenizer;
|
||||
use tokenizers::TokenizerImpl;
|
||||
|
||||
pub struct Embed {
|
||||
model: BertModel,
|
||||
tokenizer: TokenizerImpl<ModelWrapper, NormalizerWrapper, PreTokenizerWrapper, PostProcessorWrapper, DecoderWrapper>,
|
||||
pub hidden_size: usize,
|
||||
tokenizer: TokenizerImpl<
|
||||
ModelWrapper,
|
||||
NormalizerWrapper,
|
||||
PreTokenizerWrapper,
|
||||
PostProcessorWrapper,
|
||||
DecoderWrapper,
|
||||
>,
|
||||
}
|
||||
|
||||
impl Embed {
|
||||
@@ -41,9 +48,14 @@ impl Embed {
|
||||
let tokenizer = tokenizer
|
||||
.with_padding(None)
|
||||
.with_truncation(None)
|
||||
.map_err(E::msg)?.clone();
|
||||
.map_err(E::msg)?
|
||||
.clone();
|
||||
|
||||
Ok(Embed { model, tokenizer })
|
||||
Ok(Embed {
|
||||
model,
|
||||
tokenizer,
|
||||
hidden_size: config.hidden_size,
|
||||
})
|
||||
}
|
||||
|
||||
fn download_model_files(model_id: &str, revision: &str) -> Result<(PathBuf, PathBuf, PathBuf)> {
|
||||
@@ -58,7 +70,8 @@ impl Embed {
|
||||
}
|
||||
|
||||
pub(crate) fn embed(&self, prompt: &str) -> Result<Vec<f32>> {
|
||||
let tokens = self.tokenizer
|
||||
let tokens = self
|
||||
.tokenizer
|
||||
.encode(prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
@@ -68,9 +81,9 @@ impl Embed {
|
||||
let token_type_ids = token_ids.zeros_like()?;
|
||||
|
||||
let embeddings = self.model.forward(&token_ids, &token_type_ids, None)?;
|
||||
let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
|
||||
let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
|
||||
let embeddings = normalize_l2(&embeddings)?.reshape(384)?.to_vec1::<f32>()?;
|
||||
let embeddings = normalize_l2(&embeddings.sum(1)?)?
|
||||
.reshape(self.hidden_size)?
|
||||
.to_vec1::<f32>()?;
|
||||
Ok(embeddings)
|
||||
}
|
||||
}
|
||||
|
||||
30
src/lsp.rs
30
src/lsp.rs
@@ -6,7 +6,7 @@ use tower_lsp::{Client, LanguageServer};
|
||||
|
||||
pub struct Backend {
|
||||
pub client: Client,
|
||||
pub body: Arc<Mutex<String>>,
|
||||
pub body: Arc<Mutex<HashMap<Url, String>>>,
|
||||
pub appstate: crate::State,
|
||||
}
|
||||
|
||||
@@ -60,13 +60,12 @@ impl LanguageServer for Backend {
|
||||
}
|
||||
|
||||
async fn did_open(&self, params: DidOpenTextDocumentParams) {
|
||||
// TODO: build an index for multiple documents in workdir
|
||||
*self.body.lock().await = params.text_document.text;
|
||||
self.body.lock().await.insert(params.text_document.uri, params.text_document.text);
|
||||
}
|
||||
|
||||
async fn did_change(&self, params: DidChangeTextDocumentParams) {
|
||||
if let Some(body) = params.content_changes.into_iter().next() {
|
||||
*self.body.lock().await = body.text;
|
||||
self.body.lock().await.insert(params.text_document.uri, body.text);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -77,14 +76,20 @@ impl LanguageServer for Backend {
|
||||
let uri = params.text_document.uri;
|
||||
let Some(lang) = url_extension(&uri) else {
|
||||
self.client
|
||||
.log_message(MessageType::ERROR, "unable to determine filetype, file has no extension")
|
||||
.log_message(
|
||||
MessageType::ERROR,
|
||||
"unable to determine filetype, file has no extension",
|
||||
)
|
||||
.await;
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let body = self.body.lock().await.to_string();
|
||||
let body_locked = self.body.lock().await;
|
||||
let Some(body) = body_locked.get(&uri) else {
|
||||
return Ok(None);
|
||||
};
|
||||
let mut range = params.range;
|
||||
let selected_text = string_range_index(&body, range);
|
||||
let selected_text = string_range_index(body, range);
|
||||
|
||||
let Some(comment) = ParsedAction::new(selected_text) else {
|
||||
return Ok(None);
|
||||
@@ -93,14 +98,15 @@ impl LanguageServer for Backend {
|
||||
let action_response = match comment.action {
|
||||
Action::Generate => {
|
||||
range.start = range.end;
|
||||
self.appstate.generate(&lang, comment.description, 1)
|
||||
self.appstate
|
||||
.generate(&lang, comment.description, 1)
|
||||
.map(|v| v.into_iter().map(|s| format!("{s}\n")).collect())
|
||||
.map_err(|e| e.to_string())
|
||||
}
|
||||
Action::Refactor => {
|
||||
self.appstate.refactor(&lang, comment.description, selected_text, 1)
|
||||
.map_err(|e| e.to_string())
|
||||
}
|
||||
Action::Refactor => self
|
||||
.appstate
|
||||
.refactor(&lang, comment.description, selected_text, 1)
|
||||
.map_err(|e| e.to_string()),
|
||||
};
|
||||
|
||||
let closest_matches = match action_response {
|
||||
|
||||
106
src/main.rs
106
src/main.rs
@@ -1,4 +1,4 @@
|
||||
use anyhow::{Context, Error as E, Result, bail};
|
||||
use anyhow::{Context, Error as E, Result};
|
||||
use clap::Parser;
|
||||
use hora::core::{ann_index::ANNIndex, metrics::Metric::Euclidean};
|
||||
use hora::index::hnsw_idx::HNSWIndex;
|
||||
@@ -12,20 +12,9 @@ use tower_lsp::{LspService, Server};
|
||||
mod args;
|
||||
mod embed;
|
||||
mod lsp;
|
||||
mod state;
|
||||
mod mutation;
|
||||
|
||||
fn path_to_parent_base(p: &std::path::Path) -> Result<String> {
|
||||
let Some(parent) = p
|
||||
.parent()
|
||||
.and_then(|v| v.file_name())
|
||||
.and_then(|v| v.to_str())
|
||||
.map(|v| v.to_string())
|
||||
else {
|
||||
bail!("failed to parse snippets path, make sure the directory structure is valid");
|
||||
};
|
||||
Ok(parent)
|
||||
}
|
||||
mod state;
|
||||
mod sources;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
@@ -34,31 +23,29 @@ async fn main() -> Result<()> {
|
||||
let (model_id, revision) = args.resolve_model_and_revision();
|
||||
let embed = embed::Embed::new(args.gpu, &model_id, &revision)?;
|
||||
let mut dict = HashMap::default();
|
||||
let dimensions = 384;
|
||||
let dimensions = embed.hidden_size;
|
||||
|
||||
let paths = glob::glob("./snippets/v1/*/*.kdl")?;
|
||||
for path in paths {
|
||||
let path = path?;
|
||||
let parent = path_to_parent_base(&path)?;
|
||||
for (language, paths) in sources::rule_files(args.snippets.join("generate"))? {
|
||||
for path in paths {
|
||||
let current_lang_index = dict
|
||||
.entry(language.clone())
|
||||
.or_insert_with(|| HNSWIndex::new(dimensions, &Default::default()));
|
||||
|
||||
let current_lang_index = dict
|
||||
.entry(parent)
|
||||
.or_insert_with(|| HNSWIndex::new(dimensions, &Default::default()));
|
||||
let doc_str = std::fs::read_to_string(&path)?;
|
||||
let doc: KdlDocument = doc_str
|
||||
.parse()
|
||||
.context(format!("failed to parse KDL: {}", path.display()))?;
|
||||
|
||||
let doc_str = std::fs::read_to_string(&path)?;
|
||||
let doc: KdlDocument = doc_str
|
||||
.parse()
|
||||
.context(format!("failed to parse KDL: {}", path.display()))?;
|
||||
|
||||
let Some(desc) = doc.get_arg("desc").and_then(|v| v.as_string()) else {
|
||||
continue;
|
||||
};
|
||||
let Some(body) = doc.get_arg("body").and_then(|v| v.as_string()) else {
|
||||
continue;
|
||||
};
|
||||
current_lang_index
|
||||
.add(&embed.embed(desc)?, body.to_string())
|
||||
.map_err(E::msg)?;
|
||||
let Some(desc) = doc.get_arg("desc").and_then(|v| v.as_string()) else {
|
||||
continue;
|
||||
};
|
||||
let Some(body) = doc.get_arg("body").and_then(|v| v.as_string()) else {
|
||||
continue;
|
||||
};
|
||||
current_lang_index
|
||||
.add(&embed.embed(desc)?, body.to_string())
|
||||
.map_err(E::msg)?;
|
||||
}
|
||||
}
|
||||
|
||||
for index in dict.values_mut() {
|
||||
@@ -67,45 +54,42 @@ async fn main() -> Result<()> {
|
||||
.map_err(E::msg)?;
|
||||
}
|
||||
|
||||
// v2
|
||||
let paths = glob::glob("./snippets/v2/*/*.kdl")?;
|
||||
let mut v2_dict = HashMap::new();
|
||||
let mut v2_mutations_collection = vec![];
|
||||
for (i, path) in paths.enumerate() {
|
||||
let path = path?;
|
||||
let parent = path_to_parent_base(&path)?;
|
||||
let mut refactor_dict = HashMap::new();
|
||||
let mut mutations_collection = vec![];
|
||||
for (language, paths) in sources::rule_files(args.snippets.join("refactor"))? {
|
||||
for path in paths {
|
||||
let mutations = mutation::from_path(path)?;
|
||||
let current_lang_index = refactor_dict
|
||||
.entry(language.clone())
|
||||
.or_insert_with(|| HNSWIndex::new(dimensions, &Default::default()));
|
||||
|
||||
let mutations = mutation::from_path(path)?;
|
||||
let current_lang_index = v2_dict
|
||||
.entry(parent)
|
||||
.or_insert_with(|| HNSWIndex::new(dimensions, &Default::default()));
|
||||
|
||||
current_lang_index
|
||||
.add(&embed.embed(&mutations.description)?, i)
|
||||
.map_err(E::msg)?;
|
||||
v2_mutations_collection.push(mutations);
|
||||
current_lang_index
|
||||
.add(&embed.embed(&mutations.description)?, mutations_collection.len())
|
||||
.map_err(E::msg)?;
|
||||
mutations_collection.push(mutations);
|
||||
}
|
||||
}
|
||||
|
||||
for index in v2_dict.values_mut() {
|
||||
for index in refactor_dict.values_mut() {
|
||||
index.build(Euclidean).map_err(E::msg)?;
|
||||
}
|
||||
|
||||
let appstate = State {
|
||||
let appstate = State::new(
|
||||
embed,
|
||||
generate: state::Generate { dict },
|
||||
refactor: state::Refactor {
|
||||
dict: v2_dict,
|
||||
mutations_collection: v2_mutations_collection,
|
||||
state::Generate { dict },
|
||||
state::Refactor {
|
||||
dict: refactor_dict,
|
||||
mutations_collection,
|
||||
},
|
||||
};
|
||||
);
|
||||
|
||||
let stdin = tokio::io::stdin();
|
||||
let stdout = tokio::io::stdout();
|
||||
|
||||
let (service, socket) = LspService::new(|client| lsp::Backend {
|
||||
client,
|
||||
body: Arc::new(Mutex::new(String::default())),
|
||||
appstate
|
||||
body: Arc::new(Mutex::new(HashMap::default())),
|
||||
appstate,
|
||||
});
|
||||
Server::new(stdin, stdout, socket).serve(service).await;
|
||||
Ok(())
|
||||
|
||||
25
src/sources.rs
Normal file
25
src/sources.rs
Normal file
@@ -0,0 +1,25 @@
|
||||
use std::{fs, io, path::{Path, PathBuf}, collections::HashMap};
|
||||
|
||||
pub fn rule_files<P: AsRef<Path>>(path: P) -> io::Result<HashMap<String, Vec<PathBuf>>> {
|
||||
let per_language_dirs: Vec<_> = fs::read_dir(path)?
|
||||
.filter_map(|res| res.ok())
|
||||
.map(|direntry| direntry.path())
|
||||
.filter(|dir| dir.is_dir()).collect();
|
||||
|
||||
let mut basename_to_paths = HashMap::new();
|
||||
|
||||
for language_dir in per_language_dirs {
|
||||
let Some(dirname) = language_dir.file_stem().and_then(|v|v.to_str()).map(|v| v.to_string()) else {
|
||||
continue;
|
||||
};
|
||||
let rule_file_paths: Vec<_> = fs::read_dir(&language_dir)?
|
||||
.filter_map(|res| res.ok())
|
||||
.map(|entry| entry.path())
|
||||
.filter(|file| file.is_file() && file.extension().is_some_and(|ext| ext == "kdl"))
|
||||
.map(|path| path.to_path_buf())
|
||||
.collect();
|
||||
basename_to_paths.insert(dirname, rule_file_paths);
|
||||
}
|
||||
Ok(basename_to_paths)
|
||||
}
|
||||
// fn prebuilt_index();
|
||||
40
src/state.rs
40
src/state.rs
@@ -1,10 +1,10 @@
|
||||
use crate::mutation;
|
||||
use derive_more::Display;
|
||||
use derive_more::Error;
|
||||
use tree_sitter::Parser;
|
||||
use std::collections::HashMap;
|
||||
use hora::index::hnsw_idx::HNSWIndex;
|
||||
use hora::core::ann_index::ANNIndex;
|
||||
use crate::mutation;
|
||||
use hora::index::hnsw_idx::HNSWIndex;
|
||||
use std::collections::HashMap;
|
||||
use tree_sitter::Parser;
|
||||
|
||||
#[derive(Debug, Display, Error)]
|
||||
pub enum Error {
|
||||
@@ -25,7 +25,9 @@ impl Refactor {
|
||||
fn get_lang(s: &str) -> Result<tree_sitter::Language, Error> {
|
||||
Ok(match s {
|
||||
"go" => tree_sitter_go::LANGUAGE,
|
||||
"c" => tree_sitter_c::LANGUAGE,
|
||||
"c" | "h" => tree_sitter_c::LANGUAGE,
|
||||
"cpp" | "hpp" => tree_sitter_cpp::LANGUAGE,
|
||||
"js" | "ts" => tree_sitter_javascript::LANGUAGE,
|
||||
"rs" => tree_sitter_rust::LANGUAGE,
|
||||
_ => return Err(Error::UnknownLang),
|
||||
}
|
||||
@@ -68,7 +70,8 @@ impl Refactor {
|
||||
Err(e) => {
|
||||
tracing::error!(
|
||||
collection_index = index,
|
||||
"failed to apply mutations from collection {}", e
|
||||
"failed to apply mutations from collection {}",
|
||||
e
|
||||
);
|
||||
None
|
||||
}
|
||||
@@ -79,7 +82,7 @@ impl Refactor {
|
||||
}
|
||||
}
|
||||
pub struct Generate {
|
||||
pub dict: HashMap<String, HNSWIndex<f32, String>>
|
||||
pub dict: HashMap<String, HNSWIndex<f32, String>>,
|
||||
}
|
||||
|
||||
impl Generate {
|
||||
@@ -92,14 +95,19 @@ impl Generate {
|
||||
}
|
||||
|
||||
pub struct State {
|
||||
// TODO: create new constructor and private these fields
|
||||
pub embed: crate::embed::Embed,
|
||||
pub generate: Generate,
|
||||
pub refactor: Refactor,
|
||||
embed: crate::embed::Embed,
|
||||
generate: Generate,
|
||||
refactor: Refactor,
|
||||
}
|
||||
|
||||
impl State {
|
||||
|
||||
pub fn new(embed: crate::embed::Embed, generate: Generate, refactor: Refactor) -> Self {
|
||||
Self {
|
||||
embed,
|
||||
generate,
|
||||
refactor,
|
||||
}
|
||||
}
|
||||
pub fn generate(&self, lang: &str, prompt: &str, top_k: usize) -> Result<Vec<String>, Error> {
|
||||
let Ok(target) = self.embed.embed(prompt) else {
|
||||
return Err(Error::EmbedFailed);
|
||||
@@ -108,7 +116,13 @@ impl State {
|
||||
self.generate.search(lang, &target, top_k)
|
||||
}
|
||||
|
||||
pub fn refactor(&self, lang: &str, prompt: &str, body: &str, top_k: usize) -> Result<Vec<String>, Error> {
|
||||
pub fn refactor(
|
||||
&self,
|
||||
lang: &str,
|
||||
prompt: &str,
|
||||
body: &str,
|
||||
top_k: usize,
|
||||
) -> Result<Vec<String>, Error> {
|
||||
let Ok(target) = self.embed.embed(prompt) else {
|
||||
return Err(Error::EmbedFailed);
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user