Compare commits
48 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fcf98006c4 | ||
|
|
6eab02575a | ||
|
|
80fe8a3f16 | ||
|
|
60256c06cc | ||
|
|
caf4f51d22 | ||
|
|
a543e80a04 | ||
|
|
f0c137ade4 | ||
|
|
7b0f818d38 | ||
|
|
4195cbb734 | ||
|
|
e5602c688c | ||
|
|
89e4c3b5fb | ||
|
|
e24e62873f | ||
|
|
0b9ab89f35 | ||
|
|
650329206d | ||
|
|
7d9c3a448f | ||
|
|
633c1a206b | ||
|
|
ab4c62fcf4 | ||
|
|
6ff9ba9d16 | ||
|
|
e348b9a830 | ||
|
|
d359121afd | ||
|
|
4abd2cffac | ||
|
|
daccd63006 | ||
|
|
87e096f0bc | ||
|
|
91d2640c11 | ||
|
|
ec3b89f455 | ||
|
|
c734c81a04 | ||
|
|
e7cae348a1 | ||
|
|
faea784d8f | ||
|
|
e8970f21ff | ||
|
|
a1445b2f03 | ||
|
|
8f5e618841 | ||
|
|
996142c8dd | ||
|
|
b32ed48471 | ||
|
|
a4b77cd40b | ||
|
|
c137a6c586 | ||
|
|
7e6300f8e9 | ||
|
|
949e5f5df4 | ||
|
|
08df630a2b | ||
|
|
c8a62b9b42 | ||
|
|
34aeae0496 | ||
|
|
bc4f29393e | ||
|
|
dfa414163d | ||
|
|
a6b9910d39 | ||
|
|
cf11c5774d | ||
|
|
291e2e1832 | ||
|
|
bd772eec48 | ||
|
|
56dc5f4393 | ||
|
|
5dadaca767 |
8
.github/dependabot.yml
vendored
Normal file
8
.github/dependabot.yml
vendored
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
version: 2
|
||||||
|
updates:
|
||||||
|
- package-ecosystem: "cargo" # See documentation for possible values
|
||||||
|
directory: "/" # Location of package manifests
|
||||||
|
schedule:
|
||||||
|
interval: "weekly"
|
||||||
|
open-pull-requests-limit: 4
|
||||||
|
|
||||||
26
.github/workflows/release.yml
vendored
Normal file
26
.github/workflows/release.yml
vendored
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
on:
|
||||||
|
release:
|
||||||
|
types: [created]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
release:
|
||||||
|
name: release ${{ matrix.target }}
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
include:
|
||||||
|
- target: x86_64-pc-windows-gnu
|
||||||
|
archive: zip
|
||||||
|
- target: x86_64-unknown-linux-musl
|
||||||
|
archive: tar.zst
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@master
|
||||||
|
- name: Compile and release
|
||||||
|
uses: rust-build/rust-build.action@v1.4.5
|
||||||
|
env:
|
||||||
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
with:
|
||||||
|
RUSTTARGET: ${{ matrix.target }}
|
||||||
|
ARCHIVE_TYPES: ${{ matrix.archive }}
|
||||||
|
TOOLCHAIN_VERSION: stable
|
||||||
22
.github/workflows/rust.yml
vendored
Normal file
22
.github/workflows/rust.yml
vendored
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
name: Build and test
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [ "master" ]
|
||||||
|
pull_request:
|
||||||
|
branches: [ "master" ]
|
||||||
|
|
||||||
|
env:
|
||||||
|
CARGO_TERM_COLOR: always
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- name: Build
|
||||||
|
run: cargo build --verbose
|
||||||
|
- name: Run tests
|
||||||
|
run: cargo test --verbose
|
||||||
@@ -1,6 +1,11 @@
|
|||||||
[language-server.silos]
|
[language-server.silos]
|
||||||
command = "./target/debug/silos"
|
command = "./target/debug/silos"
|
||||||
|
args = ["lsp"]
|
||||||
|
|
||||||
[[language]]
|
[[language]]
|
||||||
name = "go"
|
name = "go"
|
||||||
language-servers = [ { name = "silos" } ]
|
language-servers = [ { name = "silos" }, "gopls" ]
|
||||||
|
|
||||||
|
[[language]]
|
||||||
|
name = "rust"
|
||||||
|
language-servers = [ ]
|
||||||
|
|||||||
16
.vscode/settings.json
vendored
Normal file
16
.vscode/settings.json
vendored
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
{
|
||||||
|
"vscode-lspconfig.serverConfigurations": [
|
||||||
|
{
|
||||||
|
"name": "silos",
|
||||||
|
"document_selector": [
|
||||||
|
{
|
||||||
|
"language": "go"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"command": [
|
||||||
|
"silos",
|
||||||
|
"lsp"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
837
Cargo.lock
generated
837
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
16
Cargo.toml
16
Cargo.toml
@@ -1,28 +1,28 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "silos"
|
name = "silos"
|
||||||
version = "2.0.0"
|
version = "6.0.0"
|
||||||
edition = "2024"
|
edition = "2024"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
actix-web = "4.11.0"
|
|
||||||
anyhow = "1.0.98"
|
anyhow = "1.0.98"
|
||||||
candle-core = "0.9.1"
|
candle-core = "0.9.1"
|
||||||
candle-nn = "0.9.1"
|
candle-nn = "0.9.1"
|
||||||
candle-transformers = "0.9.1"
|
candle-transformers = "0.9.1"
|
||||||
clap = { version = "4.5.39", features = ["derive"] }
|
clap = { version = "4.5.45", features = ["derive"] }
|
||||||
derive_more = "2.0.1"
|
derive_more = { version = "2.0.1", features = ["display", "error"] }
|
||||||
glob = "0.3.2"
|
|
||||||
hf-hub = "0.4.2"
|
hf-hub = "0.4.2"
|
||||||
hora = "0.1.1"
|
hora = "0.1.1"
|
||||||
kdl = "6.3.4"
|
kdl = "6.3.4"
|
||||||
serde = "1.0.219"
|
|
||||||
serde_json = "1.0.140"
|
serde_json = "1.0.140"
|
||||||
tokenizers = "0.21.1"
|
tokenizers = "0.21.4"
|
||||||
tracing = "0.1.41"
|
tracing = "0.1.41"
|
||||||
tracing-subscriber = "0.3.19"
|
tracing-subscriber = "0.3.19"
|
||||||
tree-sitter = "0.25.6"
|
tree-sitter = "0.25.8"
|
||||||
tree-sitter-c = "0.24.1"
|
tree-sitter-c = "0.24.1"
|
||||||
tree-sitter-go = "0.23.4"
|
tree-sitter-go = "0.23.4"
|
||||||
tree-sitter-rust = "0.24.0"
|
tree-sitter-rust = "0.24.0"
|
||||||
tokio = { version = "1.45.1", features = ["io-std", "macros", "rt", "rt-multi-thread"] }
|
tokio = { version = "1.45.1", features = ["io-std", "macros", "rt", "rt-multi-thread"] }
|
||||||
tower-lsp = "0.20.0"
|
tower-lsp = "0.20.0"
|
||||||
|
tree-sitter-javascript = "0.23.1"
|
||||||
|
tree-sitter-cpp = "0.23.4"
|
||||||
|
knuffel = "3.2.0"
|
||||||
|
|||||||
111
README.md
111
README.md
@@ -2,9 +2,11 @@
|
|||||||
|
|
||||||
Dumb, proomptable modular snippet search.
|
Dumb, proomptable modular snippet search.
|
||||||
|
|
||||||
## Getting started
|

|
||||||
|
|
||||||
There are no binary releases yet.
|
## Installation
|
||||||
|
|
||||||
|
You can download a binary from the releases tab or build the project from source.
|
||||||
|
|
||||||
### From source
|
### From source
|
||||||
|
|
||||||
@@ -13,26 +15,48 @@ Prerequisites:
|
|||||||
- libc
|
- libc
|
||||||
- [rust toolchain](https://rustup.rs)
|
- [rust toolchain](https://rustup.rs)
|
||||||
|
|
||||||
Clone this repository and enter it
|
Clone this repository and build it.
|
||||||
|
|
||||||
``` sh
|
``` sh
|
||||||
git clone https://github.com/lavafroth/silos
|
cargo install --git https://github.com/lavafroth/silos
|
||||||
cd silos
|
|
||||||
```
|
```
|
||||||
|
|
||||||
``` sh
|
## Editor support
|
||||||
cargo r http
|
|
||||||
|
- Helix: Use the example `.helix` directory provided to run the LSP for files under `./examples/`.
|
||||||
|
- Neovim: Please follow [the official guide](https://neovim.io/doc/user/lsp.html).
|
||||||
|
- VSCode: Use the [vscode-lspconfig](https://marketplace.visualstudio.com/items?itemName=whtsht.vscode-lspconfig) extension with the `.vscode/settings.json` provided.
|
||||||
|
|
||||||
|
Make sure to modify the binary path in the example to where you have it on your system.
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
- Write a comment above a paragraph of code, consider the example in examples/example.go
|
||||||
|
|
||||||
|
``` go
|
||||||
|
resumeFilename := "resume.pdf"
|
||||||
|
version := 3
|
||||||
|
// refactor: change the file basename to that of the parent
|
||||||
|
whereIsMyResume :=
|
||||||
|
filepath.Base(
|
||||||
|
documentsDirectory + "CV" + "_v" + strconv.Itoa(version) + "/" + resumeFilename)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
- The comment must begin with either of
|
||||||
|
- `generate: `
|
||||||
|
- `refactor: `
|
||||||
|
- Select the code to be modified along with the comment above it.
|
||||||
|
- Trigger code actions. In helix, this is `space`, `a`.
|
||||||
|
- Select the option called "ask silos."
|
||||||
|
|
||||||
> [!NOTE]
|
> [!NOTE]
|
||||||
>
|
>
|
||||||
> Embedding defaults to using the CPU. You may use the `--gpu` flag with a GPU number to use a dedicated GPU.
|
> Embedding defaults to using the CPU. You may use the `--gpu` flag with a GPU number to use a dedicated GPU.
|
||||||
|
|
||||||
An HTTP REST API listens on port 8000 and can be queried for code snippets.
|
## `generate` snippets
|
||||||
|
|
||||||
### v1 API
|
- Stored in the KDL format inside per-language directories under `./snippets/v1`.
|
||||||
|
- They must conform to the following structure
|
||||||
V1 snippets are stored in the KDL format inside per-language directories under `./snippets/v1`. They must conform to the following structure
|
|
||||||
|
|
||||||
``` kdl
|
``` kdl
|
||||||
desc "describes the snippet"
|
desc "describes the snippet"
|
||||||
@@ -43,40 +67,24 @@ KDL supports arbitrary raw strings with as many `#`s before and after the quotes
|
|||||||
|
|
||||||
See the example snippet `./snippets/v1/go/simple_worker.kdl` in the go programming language.
|
See the example snippet `./snippets/v1/go/simple_worker.kdl` in the go programming language.
|
||||||
|
|
||||||
#### Querying
|
## `refactor` snippets
|
||||||
|
|
||||||
We recommend the `jo` CLI to easily generate JSON payloads for the API.
|
This API parses code into an AST (Abstract Syntax Tree) via tree-sitter and can perform subsequent mutations.
|
||||||
|
|
||||||
``` sh
|
### Supported Languages
|
||||||
jo desc="channeled worker in go" \
|
|
||||||
curl http://localhost:8000/api/v1/get --json @-
|
|
||||||
```
|
|
||||||
|
|
||||||
You must add the "in someLanguage" suffix to your query's description field. This was a bad design choice and will be deprecated in a later release.
|
|
||||||
|
|
||||||
#### Adding a snippet
|
|
||||||
|
|
||||||
``` sh
|
|
||||||
curl http://localhost:8000/api/v1/add --json \
|
|
||||||
'{ "desc": "Build an asynchronous shared mutable state", "lang": "rust", "body": "let object = Arc::new(Mutex::new(old));" }'
|
|
||||||
```
|
|
||||||
|
|
||||||
### v2 API
|
|
||||||
|
|
||||||
The v2 API leverages tree-sitter to parse code into an AST (Abstract Syntax Tree) and perform subsequent mutations on the code.
|
|
||||||
|
|
||||||
#### Supported Languages
|
|
||||||
|
|
||||||
- C
|
- C
|
||||||
- Rust
|
- Rust
|
||||||
- Go
|
- Go
|
||||||
|
- Javascript
|
||||||
|
- C++
|
||||||
|
|
||||||
#### Defining mutation collections
|
### Defining mutation collections
|
||||||
|
|
||||||
``` kdl
|
``` kdl
|
||||||
description "describes the mutation collection"
|
description "describes the mutation collection"
|
||||||
mutation {
|
mutation {
|
||||||
expression "some ((beautiful) @adjective) AST expression"
|
expression "(some ((beautiful) @adjective) AST expression) @root"
|
||||||
substitute {
|
substitute {
|
||||||
literal "hello"
|
literal "hello"
|
||||||
capture "adjective"
|
capture "adjective"
|
||||||
@@ -85,7 +93,7 @@ mutation {
|
|||||||
}
|
}
|
||||||
|
|
||||||
mutation {
|
mutation {
|
||||||
expression "another"
|
expression "(another) @root"
|
||||||
substitute {
|
substitute {
|
||||||
literal "multiple mutations work"
|
literal "multiple mutations work"
|
||||||
literal "as long as their expression"
|
literal "as long as their expression"
|
||||||
@@ -96,31 +104,30 @@ mutation {
|
|||||||
|
|
||||||
- `description`: A textual description of the mutation collection.
|
- `description`: A textual description of the mutation collection.
|
||||||
- `mutation`: Defines individual code changes.
|
- `mutation`: Defines individual code changes.
|
||||||
- `expression`: Uses tree-sitter to match and capture AST nodes with `@` prefixes, The special `@root` node is reserved for the entire expression.
|
- `expression`: Uses tree-sitter to match and capture AST nodes with `@` prefixes,
|
||||||
|
- The special `@root` node must specify the expression to be replaced.
|
||||||
- `substitute`: Constructs the modified code using literals and captured arguments.
|
- `substitute`: Constructs the modified code using literals and captured arguments.
|
||||||
|
|
||||||
See the example mutation collection in `./snippets/v2/go/mutations.kdl`.
|
See the example mutation collection in `./snippets/v2/go/filepath-parent.kdl`.
|
||||||
|
|
||||||
#### Querying
|
- The API performs a single-pass substitution based on the closest matching mutation.
|
||||||
|
- Captured groups are used within the `substitute` block and the mutated code is returned.
|
||||||
|
|
||||||
``` sh
|
> Every capture group must contain the largest atom to be operated on.
|
||||||
jo body=@examples/example.go \
|
For example: if you wish to operate on elements of an array, capture each identifier inside the array
|
||||||
desc='change the current filepath to the parent filepath in go' \
|
|
||||||
| curl http://localhost:8000/api/v2/get --json @-
|
Correct way: Here the `array` and `identifier` only hint at where the `root` expression lies.
|
||||||
|
|
||||||
|
```
|
||||||
|
(array (identifier @root))
|
||||||
```
|
```
|
||||||
|
|
||||||
V2 queries have the following fields
|
Incorrect way: Here the root expression matches the entire block of array elements inside the braces, not each element.
|
||||||
|
|
||||||
- `desc`: Description of the query.
|
```
|
||||||
- `body`: The code to be parsed and modified.
|
(array ((identifier)*) @entire-block-capture) @root
|
||||||
|
```
|
||||||
The API performs a single-pass substitution based on the closest matching mutation. Captured groups are used within the `substitute` block and the mutated code is returned in the response JSON `body` field.
|
|
||||||
|
|
||||||
**Further reading**
|
**Further reading**
|
||||||
|
|
||||||
- [tree-sitter query syntax](https://tree-sitter.github.io/tree-sitter/using-parsers/queries/1-syntax.html) to create mutation expressions.
|
- [tree-sitter query syntax](https://tree-sitter.github.io/tree-sitter/using-parsers/queries/1-syntax.html) to create mutation expressions.
|
||||||
- [jo](https://github.com/jpmens/jo) to build the JSON body from a file.
|
|
||||||
|
|
||||||
## Coming soon
|
|
||||||
|
|
||||||
An LSP to provide Silos code actions for a given selection.
|
|
||||||
|
|||||||
BIN
assets/preview.gif
Normal file
BIN
assets/preview.gif
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1.9 MiB |
6
flake.lock
generated
6
flake.lock
generated
@@ -2,11 +2,11 @@
|
|||||||
"nodes": {
|
"nodes": {
|
||||||
"nixpkgs": {
|
"nixpkgs": {
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1749776303,
|
"lastModified": 1755020227,
|
||||||
"narHash": "sha256-OHibOvVwKqO1qvRg0r3agtd1EagW4THBcoWT7QGgcNo=",
|
"narHash": "sha256-gGmm+h0t6rY88RPTaIm3su95QvQIVjAJx558YUG4Id8=",
|
||||||
"owner": "NixOS",
|
"owner": "NixOS",
|
||||||
"repo": "nixpkgs",
|
"repo": "nixpkgs",
|
||||||
"rev": "6e7721e37bf00fa7ea44ac3cfc9d2411284ec3ef",
|
"rev": "695d5db1b8b20b73292501683a524e0bd79074fb",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
|
|||||||
13
snippets/refactor/go/base64.kdl
Normal file
13
snippets/refactor/go/base64.kdl
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
description "base64 import"
|
||||||
|
mutation {
|
||||||
|
expression "(import_spec_list ((import_spec)* @spec)) @root"
|
||||||
|
substitute {
|
||||||
|
literal "("
|
||||||
|
literal "\n"
|
||||||
|
capture "spec"
|
||||||
|
literal "\n"
|
||||||
|
literal #""base64""#
|
||||||
|
literal "\n"
|
||||||
|
literal ")"
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -4,7 +4,7 @@ mutation {
|
|||||||
(call_expression
|
(call_expression
|
||||||
function: (_) @func (#eq? @func "filepath.Base")
|
function: (_) @func (#eq? @func "filepath.Base")
|
||||||
arguments: (_) @args
|
arguments: (_) @args
|
||||||
)
|
) @root
|
||||||
"""
|
"""
|
||||||
substitute {
|
substitute {
|
||||||
literal "filepath.Base(filepath.Dir(filepath.Clean"
|
literal "filepath.Base(filepath.Dir(filepath.Clean"
|
||||||
64
src/args.rs
64
src/args.rs
@@ -1,11 +1,15 @@
|
|||||||
use clap::Parser;
|
use clap::{Args, Parser, Subcommand};
|
||||||
|
use std::path::PathBuf;
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
#[command(author, version, about, long_about = None)]
|
#[command(author, version, about, long_about = None)]
|
||||||
pub(crate) struct Args {
|
pub(crate) struct Cli {
|
||||||
/// The mode to run the server in. Defaults to LSP. The HTTP REST API can be run by specifying `http` or `http:port`. For example: `http:7047`
|
#[command(subcommand)]
|
||||||
pub(crate) mode: Option<String>,
|
pub command: Command,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Args, Debug)]
|
||||||
|
pub(crate) struct Lsp {
|
||||||
/// Run on the Nth GPU device.
|
/// Run on the Nth GPU device.
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
pub(crate) gpu: Option<usize>,
|
pub(crate) gpu: Option<usize>,
|
||||||
@@ -17,14 +21,41 @@ pub(crate) struct Args {
|
|||||||
/// Revision or branch.
|
/// Revision or branch.
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
pub(crate) revision: Option<String>,
|
pub(crate) revision: Option<String>,
|
||||||
|
|
||||||
|
/// Path to the directory containing `generate` and `refactor` snippets.
|
||||||
|
#[arg(long, default_value = "./snippets")]
|
||||||
|
pub(crate) snippets: std::path::PathBuf,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub enum RunMode {
|
#[derive(Args, Debug)]
|
||||||
Http(u16),
|
pub struct DumpExpression {
|
||||||
Lsp,
|
pub path: PathBuf,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Args {
|
#[derive(Args, Debug)]
|
||||||
|
pub struct ShowCaptures {
|
||||||
|
pub path: PathBuf,
|
||||||
|
pub expression: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Subcommand, Debug)]
|
||||||
|
pub enum Ast {
|
||||||
|
/// Dump the S expression for a given source file
|
||||||
|
DumpExpression(DumpExpression),
|
||||||
|
/// Show what parts of a source file gets captured by an S expression
|
||||||
|
ShowCaptures(ShowCaptures),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Subcommand, Debug)]
|
||||||
|
pub enum Command {
|
||||||
|
/// quick actions to dump, modify and verify abstract syntax trees
|
||||||
|
#[command(subcommand)]
|
||||||
|
Ast(Ast),
|
||||||
|
/// spawn a language server for use with a text editor
|
||||||
|
Lsp(Lsp),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Lsp {
|
||||||
pub(crate) fn resolve_model_and_revision(&self) -> (String, String) {
|
pub(crate) fn resolve_model_and_revision(&self) -> (String, String) {
|
||||||
let default_model = "sentence-transformers/all-MiniLM-L6-v2".to_string();
|
let default_model = "sentence-transformers/all-MiniLM-L6-v2".to_string();
|
||||||
let default_revision = "refs/pr/21".to_string();
|
let default_revision = "refs/pr/21".to_string();
|
||||||
@@ -36,21 +67,4 @@ impl Args {
|
|||||||
(None, None) => (default_model, default_revision),
|
(None, None) => (default_model, default_revision),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
pub(crate) fn mode(&self) -> RunMode {
|
|
||||||
let Some(http) = &self.mode else {
|
|
||||||
return RunMode::Lsp;
|
|
||||||
};
|
|
||||||
if http == "http" {
|
|
||||||
return RunMode::Http(8000);
|
|
||||||
}
|
|
||||||
let Some(port) = http.strip_prefix("http:") else {
|
|
||||||
return RunMode::Lsp;
|
|
||||||
};
|
|
||||||
|
|
||||||
let Ok(port) = port.parse() else {
|
|
||||||
return RunMode::Lsp;
|
|
||||||
};
|
|
||||||
|
|
||||||
RunMode::Http(port)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
45
src/embed.rs
45
src/embed.rs
@@ -7,10 +7,24 @@ use hf_hub::Repo;
|
|||||||
use hf_hub::RepoType;
|
use hf_hub::RepoType;
|
||||||
use hf_hub::api::sync::Api;
|
use hf_hub::api::sync::Api;
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
|
use tokenizers::DecoderWrapper;
|
||||||
|
use tokenizers::ModelWrapper;
|
||||||
|
use tokenizers::NormalizerWrapper;
|
||||||
|
use tokenizers::PostProcessorWrapper;
|
||||||
|
use tokenizers::PreTokenizerWrapper;
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
|
use tokenizers::TokenizerImpl;
|
||||||
|
|
||||||
pub struct Embed {
|
pub struct Embed {
|
||||||
model: BertModel,
|
model: BertModel,
|
||||||
tokenizer: Tokenizer,
|
pub hidden_size: usize,
|
||||||
|
tokenizer: TokenizerImpl<
|
||||||
|
ModelWrapper,
|
||||||
|
NormalizerWrapper,
|
||||||
|
PreTokenizerWrapper,
|
||||||
|
PostProcessorWrapper,
|
||||||
|
DecoderWrapper,
|
||||||
|
>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Embed {
|
impl Embed {
|
||||||
@@ -26,12 +40,22 @@ impl Embed {
|
|||||||
|
|
||||||
let config = std::fs::read_to_string(config_path)?;
|
let config = std::fs::read_to_string(config_path)?;
|
||||||
let config: Config = serde_json::from_str(&config)?;
|
let config: Config = serde_json::from_str(&config)?;
|
||||||
let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(E::msg)?;
|
let mut tokenizer = Tokenizer::from_file(tokenizer_path).map_err(E::msg)?;
|
||||||
|
|
||||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[weights_path], DTYPE, &device)? };
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[weights_path], DTYPE, &device)? };
|
||||||
let model = BertModel::load(vb, &config)?;
|
let model = BertModel::load(vb, &config)?;
|
||||||
|
|
||||||
Ok(Embed { model, tokenizer })
|
let tokenizer = tokenizer
|
||||||
|
.with_padding(None)
|
||||||
|
.with_truncation(None)
|
||||||
|
.map_err(E::msg)?
|
||||||
|
.clone();
|
||||||
|
|
||||||
|
Ok(Embed {
|
||||||
|
model,
|
||||||
|
tokenizer,
|
||||||
|
hidden_size: config.hidden_size,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn download_model_files(model_id: &str, revision: &str) -> Result<(PathBuf, PathBuf, PathBuf)> {
|
fn download_model_files(model_id: &str, revision: &str) -> Result<(PathBuf, PathBuf, PathBuf)> {
|
||||||
@@ -45,14 +69,9 @@ impl Embed {
|
|||||||
Ok((config, tokenizer, weights))
|
Ok((config, tokenizer, weights))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn embed(&mut self, prompt: &str) -> Result<Vec<f32>> {
|
pub(crate) fn embed(&self, prompt: &str) -> Result<Vec<f32>> {
|
||||||
let tokenizer = self
|
let tokens = self
|
||||||
.tokenizer
|
.tokenizer
|
||||||
.with_padding(None)
|
|
||||||
.with_truncation(None)
|
|
||||||
.map_err(E::msg)?;
|
|
||||||
|
|
||||||
let tokens = tokenizer
|
|
||||||
.encode(prompt, true)
|
.encode(prompt, true)
|
||||||
.map_err(E::msg)?
|
.map_err(E::msg)?
|
||||||
.get_ids()
|
.get_ids()
|
||||||
@@ -62,9 +81,9 @@ impl Embed {
|
|||||||
let token_type_ids = token_ids.zeros_like()?;
|
let token_type_ids = token_ids.zeros_like()?;
|
||||||
|
|
||||||
let embeddings = self.model.forward(&token_ids, &token_type_ids, None)?;
|
let embeddings = self.model.forward(&token_ids, &token_type_ids, None)?;
|
||||||
let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
|
let embeddings = normalize_l2(&embeddings.sum(1)?)?
|
||||||
let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
|
.reshape(self.hidden_size)?
|
||||||
let embeddings = normalize_l2(&embeddings)?.reshape(384)?.to_vec1::<f32>()?;
|
.to_vec1::<f32>()?;
|
||||||
Ok(embeddings)
|
Ok(embeddings)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
134
src/lsp.rs
134
src/lsp.rs
@@ -1,20 +1,16 @@
|
|||||||
use crate::StateWrapper;
|
|
||||||
use crate::v2;
|
|
||||||
use actix_web::web::Data;
|
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::Mutex;
|
||||||
use tower_lsp::lsp_types::*;
|
use tower_lsp::lsp_types::*;
|
||||||
use tower_lsp::{Client, LanguageServer};
|
use tower_lsp::{Client, LanguageServer};
|
||||||
use tracing::error;
|
|
||||||
|
|
||||||
pub struct Backend {
|
pub struct Backend {
|
||||||
pub client: Client,
|
pub client: Client,
|
||||||
pub body: Arc<Mutex<String>>,
|
pub body: Arc<Mutex<HashMap<Url, String>>>,
|
||||||
pub appstate: Data<StateWrapper>,
|
pub appstate: crate::State,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn string_range_index(s: &str, r: Range) -> &str {
|
fn string_range_index(s: &str, r: Range) -> &str {
|
||||||
let mut newline_count = 0;
|
let mut newline_count = 0;
|
||||||
let mut start = None;
|
let mut start = None;
|
||||||
let mut end = None;
|
let mut end = None;
|
||||||
@@ -44,11 +40,9 @@ impl LanguageServer for Backend {
|
|||||||
text_document_sync: Some(TextDocumentSyncCapability::Kind(
|
text_document_sync: Some(TextDocumentSyncCapability::Kind(
|
||||||
TextDocumentSyncKind::FULL,
|
TextDocumentSyncKind::FULL,
|
||||||
)),
|
)),
|
||||||
code_action_provider: Some(
|
code_action_provider: Some(CodeActionProviderCapability::Options(
|
||||||
tower_lsp::lsp_types::CodeActionProviderCapability::Options(
|
CodeActionOptions::default(),
|
||||||
CodeActionOptions::default(),
|
)),
|
||||||
),
|
|
||||||
),
|
|
||||||
..Default::default()
|
..Default::default()
|
||||||
},
|
},
|
||||||
..Default::default()
|
..Default::default()
|
||||||
@@ -66,13 +60,18 @@ impl LanguageServer for Backend {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn did_open(&self, params: DidOpenTextDocumentParams) {
|
async fn did_open(&self, params: DidOpenTextDocumentParams) {
|
||||||
// TODO: build an index for multiple documents in workdir
|
self.body
|
||||||
*self.body.lock().await = params.text_document.text;
|
.lock()
|
||||||
|
.await
|
||||||
|
.insert(params.text_document.uri, params.text_document.text);
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn did_change(&self, params: DidChangeTextDocumentParams) {
|
async fn did_change(&self, params: DidChangeTextDocumentParams) {
|
||||||
if let Some(body) = params.content_changes.into_iter().next() {
|
if let Some(body) = params.content_changes.into_iter().next() {
|
||||||
*self.body.lock().await = body.text;
|
self.body
|
||||||
|
.lock()
|
||||||
|
.await
|
||||||
|
.insert(params.text_document.uri, body.text);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -81,48 +80,59 @@ impl LanguageServer for Backend {
|
|||||||
params: CodeActionParams,
|
params: CodeActionParams,
|
||||||
) -> tower_lsp::jsonrpc::Result<Option<CodeActionResponse>> {
|
) -> tower_lsp::jsonrpc::Result<Option<CodeActionResponse>> {
|
||||||
let uri = params.text_document.uri;
|
let uri = params.text_document.uri;
|
||||||
let extension = url_extension(&uri);
|
let Some(lang) = url_extension(&uri) else {
|
||||||
let body = self.body.lock().await.to_string();
|
self.client
|
||||||
|
.log_message(
|
||||||
let range = params.range;
|
MessageType::ERROR,
|
||||||
let new_text = string_range_index(&body, range);
|
"unable to determine filetype, file has no extension",
|
||||||
let Some((_before, after)) = new_text.split_once("silos: ") else {
|
)
|
||||||
return Ok(None);
|
.await;
|
||||||
};
|
|
||||||
let Some((desc, _after)) = after.split_once("\n") else {
|
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
};
|
};
|
||||||
|
|
||||||
let (prompt, lang) = if let Some(ext) = extension {
|
let body_locked = self.body.lock().await;
|
||||||
(desc, ext)
|
let Some(body) = body_locked.get(&uri) else {
|
||||||
} else if let Some((prompt, lang)) = desc.rsplit_once(" in ") {
|
return Ok(None);
|
||||||
(prompt, lang.to_string())
|
};
|
||||||
} else {
|
let mut range = params.range;
|
||||||
error!("{}", v2::errors::Error::MissingSuffix);
|
let selected_text = string_range_index(body, range);
|
||||||
|
|
||||||
|
let Some(comment) = ParsedAction::new(selected_text) else {
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
};
|
};
|
||||||
|
|
||||||
let closest_matches =
|
let action_response = match comment.action {
|
||||||
match v2::api::closest_mutation(&lang, prompt, &body, 1, &self.appstate) {
|
Action::Generate => {
|
||||||
Ok(v) => v,
|
range.start = range.end;
|
||||||
Err(e) => {
|
self.appstate
|
||||||
error!("{}", e);
|
.generate(&lang, comment.description, 1)
|
||||||
return Ok(None);
|
.map(|v| v.into_iter().map(|s| format!("{s}\n")).collect())
|
||||||
}
|
.map_err(|e| e.to_string())
|
||||||
};
|
}
|
||||||
|
Action::Refactor => self
|
||||||
|
.appstate
|
||||||
|
.refactor(&lang, comment.description, selected_text, 1)
|
||||||
|
.map_err(|e| e.to_string()),
|
||||||
|
};
|
||||||
|
|
||||||
let Some(closest) = closest_matches.into_iter().next() else {
|
let closest_matches = match action_response {
|
||||||
|
Ok(v) => v,
|
||||||
|
Err(e) => {
|
||||||
|
self.client
|
||||||
|
.log_message(MessageType::ERROR, e.to_string())
|
||||||
|
.await;
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let Some(new_text) = closest_matches.into_iter().next() else {
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
};
|
};
|
||||||
let text_edit = TextEdit {
|
let text_edit = TextEdit { range, new_text };
|
||||||
range,
|
|
||||||
new_text: closest,
|
|
||||||
};
|
|
||||||
let changes: HashMap<Url, _> = [(uri, vec![text_edit])].into_iter().collect();
|
let changes: HashMap<Url, _> = [(uri, vec![text_edit])].into_iter().collect();
|
||||||
let edit = Some(WorkspaceEdit {
|
let edit = Some(WorkspaceEdit {
|
||||||
changes: Some(changes),
|
changes: Some(changes),
|
||||||
document_changes: None,
|
..Default::default()
|
||||||
change_annotations: None,
|
|
||||||
});
|
});
|
||||||
let actions = vec![CodeActionOrCommand::CodeAction(CodeAction {
|
let actions = vec![CodeActionOrCommand::CodeAction(CodeAction {
|
||||||
title: "ask silos".to_string(),
|
title: "ask silos".to_string(),
|
||||||
@@ -133,6 +143,40 @@ impl LanguageServer for Backend {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub struct ParsedAction<'a> {
|
||||||
|
action: Action,
|
||||||
|
description: &'a str,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub enum Action {
|
||||||
|
Generate,
|
||||||
|
Refactor,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> ParsedAction<'a> {
|
||||||
|
fn new(comment: &'a str) -> Option<ParsedAction<'a>> {
|
||||||
|
let upto_newline = match comment.rsplit_once("\n") {
|
||||||
|
Some((upto_newline, _discard)) => upto_newline,
|
||||||
|
None => comment,
|
||||||
|
};
|
||||||
|
let maybe_generate =
|
||||||
|
upto_newline
|
||||||
|
.split_once("generate: ")
|
||||||
|
.map(|(_discard, description)| ParsedAction {
|
||||||
|
action: Action::Generate,
|
||||||
|
description,
|
||||||
|
});
|
||||||
|
let maybe_refactor =
|
||||||
|
upto_newline
|
||||||
|
.split_once("refactor: ")
|
||||||
|
.map(|(_discard, description)| ParsedAction {
|
||||||
|
action: Action::Refactor,
|
||||||
|
description,
|
||||||
|
});
|
||||||
|
maybe_generate.or(maybe_refactor)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn url_extension(u: &Url) -> Option<String> {
|
fn url_extension(u: &Url) -> Option<String> {
|
||||||
let file_path = u.to_file_path().ok()?;
|
let file_path = u.to_file_path().ok()?;
|
||||||
|
|
||||||
|
|||||||
170
src/main.rs
170
src/main.rs
@@ -1,10 +1,9 @@
|
|||||||
use actix_web::{App, HttpServer, web};
|
use anyhow::{Context, Error as E, Result};
|
||||||
use anyhow::{Context, Error as E, Result, bail};
|
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use hora::core::{ann_index::ANNIndex, metrics::Metric::Euclidean};
|
use hora::core::{ann_index::ANNIndex, metrics::Metric::Euclidean};
|
||||||
use hora::index::hnsw_idx::HNSWIndex;
|
use hora::index::hnsw_idx::HNSWIndex;
|
||||||
use kdl::KdlDocument;
|
use kdl::KdlDocument;
|
||||||
use state::{State, StateWrapper};
|
use state::{State, dump_expression};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::Mutex;
|
||||||
@@ -13,57 +12,69 @@ use tower_lsp::{LspService, Server};
|
|||||||
mod args;
|
mod args;
|
||||||
mod embed;
|
mod embed;
|
||||||
mod lsp;
|
mod lsp;
|
||||||
|
mod mutation;
|
||||||
|
mod sources;
|
||||||
mod state;
|
mod state;
|
||||||
mod v1;
|
|
||||||
mod v2;
|
|
||||||
|
|
||||||
fn path_to_parent_base(p: &std::path::Path) -> Result<String> {
|
#[tokio::main]
|
||||||
let Some(parent) = p
|
|
||||||
.parent()
|
|
||||||
.and_then(|v| v.file_name())
|
|
||||||
.and_then(|v| v.to_str())
|
|
||||||
.map(|v| v.to_string())
|
|
||||||
else {
|
|
||||||
bail!("failed to parse snippets path, make sure the directory structure is valid");
|
|
||||||
};
|
|
||||||
Ok(parent)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[actix_web::main]
|
|
||||||
async fn main() -> Result<()> {
|
async fn main() -> Result<()> {
|
||||||
tracing_subscriber::fmt::init();
|
tracing_subscriber::fmt::init();
|
||||||
let args = args::Args::parse();
|
let args = match args::Cli::parse().command {
|
||||||
let mode = args.mode();
|
args::Command::Ast(ast) => {
|
||||||
|
match ast {
|
||||||
|
args::Ast::DumpExpression(source_file) => {
|
||||||
|
println!("{}", dump_expression(&source_file.path)?);
|
||||||
|
}
|
||||||
|
args::Ast::ShowCaptures(show_captures) => {
|
||||||
|
let source = std::fs::read_to_string(&show_captures.path).unwrap();
|
||||||
|
let source_bytes = source.as_bytes();
|
||||||
|
let extension = show_captures.path.extension().unwrap().to_str().unwrap();
|
||||||
|
let langfn = state::Refactor::get_lang(extension).unwrap();
|
||||||
|
let mut parser = tree_sitter::Parser::new();
|
||||||
|
parser.set_language(&langfn).unwrap();
|
||||||
|
let tree = parser.parse(source_bytes, None).unwrap();
|
||||||
|
let root_node = tree.root_node();
|
||||||
|
let cooked = mutation::query(
|
||||||
|
root_node,
|
||||||
|
&show_captures.expression,
|
||||||
|
&langfn,
|
||||||
|
source_bytes,
|
||||||
|
);
|
||||||
|
println!("{:#?}", cooked);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
args::Command::Lsp(lsp) => lsp,
|
||||||
|
};
|
||||||
|
|
||||||
let (model_id, revision) = args.resolve_model_and_revision();
|
let (model_id, revision) = args.resolve_model_and_revision();
|
||||||
let mut embed = embed::Embed::new(args.gpu, &model_id, &revision)?;
|
|
||||||
|
let embed = embed::Embed::new(args.gpu, &model_id, &revision)?;
|
||||||
let mut dict = HashMap::default();
|
let mut dict = HashMap::default();
|
||||||
|
let dimensions = embed.hidden_size;
|
||||||
|
|
||||||
let paths = glob::glob("./snippets/v1/*/*.kdl")?;
|
for (language, paths) in sources::rule_files(args.snippets.join("generate"))? {
|
||||||
for path in paths {
|
for path in paths {
|
||||||
let path = path?;
|
let current_lang_index = dict
|
||||||
let parent = path_to_parent_base(&path)?;
|
.entry(language.clone())
|
||||||
|
.or_insert_with(|| HNSWIndex::new(dimensions, &Default::default()));
|
||||||
|
|
||||||
let current_lang_index = dict.entry(parent).or_insert_with(|| {
|
let doc_str = std::fs::read_to_string(&path)?;
|
||||||
let dimension = 384;
|
let doc: KdlDocument = doc_str
|
||||||
let params = hora::index::hnsw_params::HNSWParams::<f32>::default();
|
.parse()
|
||||||
|
.context(format!("failed to parse KDL: {}", path.display()))?;
|
||||||
|
|
||||||
HNSWIndex::<f32, String>::new(dimension, &params)
|
let Some(desc) = doc.get_arg("desc").and_then(|v| v.as_string()) else {
|
||||||
});
|
continue;
|
||||||
|
};
|
||||||
let doc_str = std::fs::read_to_string(&path)?;
|
let Some(body) = doc.get_arg("body").and_then(|v| v.as_string()) else {
|
||||||
let doc: KdlDocument = doc_str
|
continue;
|
||||||
.parse()
|
};
|
||||||
.context(format!("failed to parse KDL: {}", path.display()))?;
|
current_lang_index
|
||||||
|
.add(&embed.embed(desc)?, body.to_string())
|
||||||
let Some(desc) = doc.get_arg("desc").and_then(|v| v.as_string()) else {
|
.map_err(E::msg)?;
|
||||||
continue;
|
}
|
||||||
};
|
|
||||||
let Some(body) = doc.get_arg("body").and_then(|v| v.as_string()) else {
|
|
||||||
continue;
|
|
||||||
};
|
|
||||||
current_lang_index
|
|
||||||
.add(&embed.embed(desc)?, body.to_string())
|
|
||||||
.map_err(E::msg)?;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for index in dict.values_mut() {
|
for index in dict.values_mut() {
|
||||||
@@ -72,64 +83,45 @@ async fn main() -> Result<()> {
|
|||||||
.map_err(E::msg)?;
|
.map_err(E::msg)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
// v2 stuff
|
let mut refactor_dict = HashMap::new();
|
||||||
let paths = glob::glob("./snippets/v2/*/*.kdl")?;
|
let mut mutations_collection = vec![];
|
||||||
let mut v2_dict = HashMap::new();
|
for (language, paths) in sources::rule_files(args.snippets.join("refactor"))? {
|
||||||
let mut v2_mutations_collection = vec![];
|
for path in paths {
|
||||||
for (i, path) in paths.enumerate() {
|
let mutations = mutation::from_path(path)?;
|
||||||
let path = path?;
|
let current_lang_index = refactor_dict
|
||||||
let parent = path_to_parent_base(&path)?;
|
.entry(language.clone())
|
||||||
|
.or_insert_with(|| HNSWIndex::new(dimensions, &Default::default()));
|
||||||
|
|
||||||
let mutations = v2::mutation::from_path(path)?;
|
current_lang_index
|
||||||
let current_lang_index = v2_dict.entry(parent).or_insert_with(|| {
|
.add(
|
||||||
let dimension = 384;
|
&embed.embed(&mutations.description)?,
|
||||||
let params = hora::index::hnsw_params::HNSWParams::<f32>::default();
|
mutations_collection.len(),
|
||||||
|
)
|
||||||
HNSWIndex::<f32, usize>::new(dimension, &params)
|
.map_err(E::msg)?;
|
||||||
});
|
mutations_collection.push(mutations);
|
||||||
|
}
|
||||||
current_lang_index
|
|
||||||
.add(&embed.embed(&mutations.description)?, i)
|
|
||||||
.map_err(E::msg)?;
|
|
||||||
v2_mutations_collection.push(mutations);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for index in v2_dict.values_mut() {
|
for index in refactor_dict.values_mut() {
|
||||||
index.build(Euclidean).map_err(E::msg)?;
|
index.build(Euclidean).map_err(E::msg)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
let appstate = State {
|
let appstate = State::new(
|
||||||
embed,
|
embed,
|
||||||
v1: v1::api::State { dict },
|
state::Generate { dict },
|
||||||
v2: v2::api::State {
|
state::Refactor {
|
||||||
dict: v2_dict,
|
dict: refactor_dict,
|
||||||
mutations_collection: v2_mutations_collection,
|
mutations_collection,
|
||||||
},
|
},
|
||||||
};
|
);
|
||||||
|
|
||||||
let appstate_wrapped = web::Data::new(appstate.build());
|
|
||||||
|
|
||||||
if let args::RunMode::Http(port) = mode {
|
|
||||||
return HttpServer::new(move || {
|
|
||||||
App::new()
|
|
||||||
.app_data(appstate_wrapped.clone())
|
|
||||||
.service(v1::api::get_snippet)
|
|
||||||
.service(v1::api::add_snippet)
|
|
||||||
.service(v2::api::get_snippet)
|
|
||||||
})
|
|
||||||
.bind(("127.0.0.1", port))?
|
|
||||||
.run()
|
|
||||||
.await
|
|
||||||
.map_err(E::from);
|
|
||||||
};
|
|
||||||
|
|
||||||
let stdin = tokio::io::stdin();
|
let stdin = tokio::io::stdin();
|
||||||
let stdout = tokio::io::stdout();
|
let stdout = tokio::io::stdout();
|
||||||
|
|
||||||
let (service, socket) = LspService::new(|client| lsp::Backend {
|
let (service, socket) = LspService::new(|client| lsp::Backend {
|
||||||
client,
|
client,
|
||||||
body: Arc::new(Mutex::new(String::default())),
|
body: Arc::new(Mutex::new(HashMap::default())),
|
||||||
appstate: appstate_wrapped.clone(),
|
appstate,
|
||||||
});
|
});
|
||||||
Server::new(stdin, stdout, socket).serve(service).await;
|
Server::new(stdin, stdout, socket).serve(service).await;
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|||||||
@@ -2,92 +2,34 @@ use std::collections::HashMap;
|
|||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
use tracing::debug;
|
use tracing::debug;
|
||||||
use tree_sitter::{Language, Node, Query, QueryCursor, StreamingIterator};
|
use tree_sitter::{Language, Node, Query, QueryCursor, StreamingIterator};
|
||||||
|
use anyhow::Result;
|
||||||
|
|
||||||
use anyhow::{Result, bail};
|
#[derive(knuffel::Decode, Debug)]
|
||||||
use kdl::KdlDocument;
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct Mutation {
|
|
||||||
pub expression: String,
|
|
||||||
pub substitute: Vec<Substitute>,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct MutationCollection {
|
pub struct MutationCollection {
|
||||||
|
#[knuffel(child, unwrap(argument))]
|
||||||
pub description: String,
|
pub description: String,
|
||||||
pub mutations: Vec<Mutation>,
|
#[knuffel(children)]
|
||||||
|
mutations: Vec<Mutation>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(knuffel::Decode, Debug)]
|
||||||
pub enum Substitute {
|
struct Mutation {
|
||||||
Literal(String),
|
#[knuffel(child, unwrap(argument))]
|
||||||
Capture(String),
|
expression: String,
|
||||||
|
#[knuffel(child, unwrap(children))]
|
||||||
|
substitute: Vec<Substitute>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(knuffel::Decode, Debug)]
|
||||||
|
enum Substitute {
|
||||||
|
Literal(#[knuffel(argument)] String),
|
||||||
|
Capture(#[knuffel(argument)] String),
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn from_path<P: AsRef<Path>>(path: P) -> Result<MutationCollection> {
|
pub fn from_path<P: AsRef<Path>>(path: P) -> Result<MutationCollection> {
|
||||||
let contents = std::fs::read_to_string(path)?;
|
let contents = std::fs::read_to_string(&path)?;
|
||||||
let doc: KdlDocument = contents.parse()?;
|
let val = knuffel::parse(path.as_ref().to_str().unwrap(), &contents)?;
|
||||||
let mut mutations = vec![];
|
Ok(val)
|
||||||
|
|
||||||
let mut description = None;
|
|
||||||
|
|
||||||
for node in doc.nodes() {
|
|
||||||
let node_name = node.name().value();
|
|
||||||
|
|
||||||
if node_name != "mutation" && node_name != "description" {
|
|
||||||
bail!(
|
|
||||||
"document root must only contain `mutation` or `description` nodes: got {node_name}"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if node_name == "description" {
|
|
||||||
description.replace(
|
|
||||||
node.entry(0)
|
|
||||||
.unwrap()
|
|
||||||
.value()
|
|
||||||
.as_string()
|
|
||||||
.unwrap()
|
|
||||||
.to_string(),
|
|
||||||
);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
let node = node.children().unwrap();
|
|
||||||
let Some(expression) = node.get_arg("expression").and_then(|v| v.as_string()) else {
|
|
||||||
bail!("mutation node must contain an expression");
|
|
||||||
};
|
|
||||||
let Some(substitute) = node.get("substitute") else {
|
|
||||||
bail!("mutation node must contain an substitute");
|
|
||||||
};
|
|
||||||
|
|
||||||
let children = substitute.children().unwrap().nodes();
|
|
||||||
let mut substitute = vec![];
|
|
||||||
for child in children {
|
|
||||||
let attrib = child.entry(0).unwrap().value().as_string().unwrap();
|
|
||||||
let substitutor = match child.name().value() {
|
|
||||||
"literal" => Substitute::Literal(attrib.to_string()),
|
|
||||||
"capture" => Substitute::Capture(attrib.to_string()),
|
|
||||||
_ => unreachable!(),
|
|
||||||
};
|
|
||||||
|
|
||||||
substitute.push(substitutor);
|
|
||||||
}
|
|
||||||
|
|
||||||
let expression = format!("({expression}) @root");
|
|
||||||
|
|
||||||
mutations.push(Mutation {
|
|
||||||
expression,
|
|
||||||
substitute,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
let Some(description) = description else {
|
|
||||||
bail!("mutation collection contains no `description`");
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(MutationCollection {
|
|
||||||
description,
|
|
||||||
mutations,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn apply(
|
pub fn apply(
|
||||||
@@ -127,7 +69,7 @@ pub fn apply(
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
struct QueryCooked {
|
pub struct QueryCooked {
|
||||||
captures: HashMap<String, String>,
|
captures: HashMap<String, String>,
|
||||||
end: usize,
|
end: usize,
|
||||||
start: usize,
|
start: usize,
|
||||||
@@ -152,7 +94,7 @@ fn split_at_indices<'a>(c: &'a [u8], idx: &[usize]) -> SplitMap<'a> {
|
|||||||
SplitMap { values, indices }
|
SplitMap { values, indices }
|
||||||
}
|
}
|
||||||
|
|
||||||
fn query<'a>(
|
pub fn query<'a>(
|
||||||
node: Node<'a>,
|
node: Node<'a>,
|
||||||
expr: &'a str,
|
expr: &'a str,
|
||||||
lang: &Language,
|
lang: &Language,
|
||||||
@@ -164,6 +106,7 @@ fn query<'a>(
|
|||||||
let mut query_matches = qc.matches(&query, node, source_bytes);
|
let mut query_matches = qc.matches(&query, node, source_bytes);
|
||||||
|
|
||||||
let capture_names = query.capture_names();
|
let capture_names = query.capture_names();
|
||||||
|
// println!("names: {capture_names:#?}");
|
||||||
|
|
||||||
let mut cooked = vec![];
|
let mut cooked = vec![];
|
||||||
|
|
||||||
@@ -171,19 +114,36 @@ fn query<'a>(
|
|||||||
let mut capture_cooked = HashMap::new();
|
let mut capture_cooked = HashMap::new();
|
||||||
let mut start = 0;
|
let mut start = 0;
|
||||||
let mut end = 0;
|
let mut end = 0;
|
||||||
for cap in matcha.captures {
|
if matcha.captures.is_empty() {
|
||||||
let Some(name) = capture_names.get(cap.index as usize) else {
|
continue;
|
||||||
continue;
|
}
|
||||||
};
|
// println!("match {:#?}", matcha.id());
|
||||||
if *name == "root" {
|
|
||||||
start = cap.node.start_byte();
|
for (ix, name) in capture_names.iter().enumerate() {
|
||||||
end = cap.node.end_byte();
|
let nodes = matcha.nodes_for_capture_index(ix.try_into().unwrap());
|
||||||
|
let mut start_pos = None;
|
||||||
|
let mut end_pos = None;
|
||||||
|
// println!("matches for {name}");
|
||||||
|
for node in nodes {
|
||||||
|
if start_pos.is_none() {
|
||||||
|
start_pos.replace(node.start_byte());
|
||||||
|
}
|
||||||
|
end_pos.replace(node.end_byte());
|
||||||
|
// println!("hit {node:#?}");
|
||||||
|
}
|
||||||
|
if start_pos.or(end_pos).is_none() {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
capture_cooked.insert(
|
if *name == "root" {
|
||||||
name.to_string(),
|
start = start_pos.unwrap();
|
||||||
cap.node.utf8_text(source_bytes).unwrap().to_string(),
|
end = end_pos.unwrap();
|
||||||
);
|
}
|
||||||
|
let range = start_pos.unwrap()..end_pos.unwrap();
|
||||||
|
// println!("match range for {name}: {:#?}", range);
|
||||||
|
let text_bytes = &source_bytes[range];
|
||||||
|
let text = std::str::from_utf8(text_bytes).unwrap();
|
||||||
|
// println!("text: {text}");
|
||||||
|
capture_cooked.insert(name.to_string(), text.to_string());
|
||||||
}
|
}
|
||||||
cooked.push(QueryCooked {
|
cooked.push(QueryCooked {
|
||||||
start,
|
start,
|
||||||
34
src/sources.rs
Normal file
34
src/sources.rs
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
use std::{
|
||||||
|
collections::HashMap,
|
||||||
|
fs, io,
|
||||||
|
path::{Path, PathBuf},
|
||||||
|
};
|
||||||
|
|
||||||
|
pub fn rule_files<P: AsRef<Path>>(path: P) -> io::Result<HashMap<String, Vec<PathBuf>>> {
|
||||||
|
let per_language_dirs: Vec<_> = fs::read_dir(path)?
|
||||||
|
.filter_map(|res| res.ok())
|
||||||
|
.map(|direntry| direntry.path())
|
||||||
|
.filter(|dir| dir.is_dir())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let mut basename_to_paths = HashMap::new();
|
||||||
|
|
||||||
|
for language_dir in per_language_dirs {
|
||||||
|
let Some(dirname) = language_dir
|
||||||
|
.file_stem()
|
||||||
|
.and_then(|v| v.to_str())
|
||||||
|
.map(|v| v.to_string())
|
||||||
|
else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
let rule_file_paths: Vec<_> = fs::read_dir(&language_dir)?
|
||||||
|
.filter_map(|res| res.ok())
|
||||||
|
.map(|entry| entry.path())
|
||||||
|
.filter(|file| file.is_file() && file.extension().is_some_and(|ext| ext == "kdl"))
|
||||||
|
.map(|path| path.to_path_buf())
|
||||||
|
.collect();
|
||||||
|
basename_to_paths.insert(dirname, rule_file_paths);
|
||||||
|
}
|
||||||
|
Ok(basename_to_paths)
|
||||||
|
}
|
||||||
|
// fn prebuilt_index();
|
||||||
153
src/state.rs
153
src/state.rs
@@ -1,19 +1,154 @@
|
|||||||
use std::sync::Mutex;
|
use crate::mutation;
|
||||||
|
use derive_more::Display;
|
||||||
|
use derive_more::Error;
|
||||||
|
use hora::core::ann_index::ANNIndex;
|
||||||
|
use hora::index::hnsw_idx::HNSWIndex;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::path::Path;
|
||||||
|
use tree_sitter::Parser;
|
||||||
|
|
||||||
pub struct StateWrapper {
|
#[derive(Debug, Display, Error)]
|
||||||
pub inner: Mutex<State>,
|
pub enum Error {
|
||||||
|
#[display("failed to embed prompt")]
|
||||||
|
EmbedFailed,
|
||||||
|
#[display("snippets were requested for an unknown language")]
|
||||||
|
UnknownLang,
|
||||||
|
#[display("failed to parse corpus of code to apply mutation on")]
|
||||||
|
SnippetParsing,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Refactor {
|
||||||
|
pub dict: HashMap<String, HNSWIndex<f32, usize>>,
|
||||||
|
pub mutations_collection: Vec<mutation::MutationCollection>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Refactor {
|
||||||
|
pub fn get_lang(s: &str) -> Result<tree_sitter::Language, Error> {
|
||||||
|
Ok(match s {
|
||||||
|
"go" => tree_sitter_go::LANGUAGE,
|
||||||
|
"c" | "h" => tree_sitter_c::LANGUAGE,
|
||||||
|
"cpp" | "hpp" => tree_sitter_cpp::LANGUAGE,
|
||||||
|
"js" | "ts" => tree_sitter_javascript::LANGUAGE,
|
||||||
|
"rs" => tree_sitter_rust::LANGUAGE,
|
||||||
|
_ => return Err(Error::UnknownLang),
|
||||||
|
}
|
||||||
|
.into())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn search(
|
||||||
|
&self,
|
||||||
|
lang: &str,
|
||||||
|
target: &[f32],
|
||||||
|
body: &str,
|
||||||
|
top_k: usize,
|
||||||
|
) -> Result<Vec<String>, Error> {
|
||||||
|
let langfn = Self::get_lang(lang)?;
|
||||||
|
let mut parser = Parser::new();
|
||||||
|
parser
|
||||||
|
.set_language(&langfn)
|
||||||
|
.map_err(|_| Error::UnknownLang)?;
|
||||||
|
|
||||||
|
let source_code = body;
|
||||||
|
let source_bytes = source_code.as_bytes();
|
||||||
|
let tree = parser
|
||||||
|
.parse(source_code, None)
|
||||||
|
.ok_or(Error::SnippetParsing)?;
|
||||||
|
let root_node = tree.root_node();
|
||||||
|
|
||||||
|
// search for k nearest neighbors
|
||||||
|
let collected = self.dict[lang]
|
||||||
|
.search(target, top_k)
|
||||||
|
.iter()
|
||||||
|
.filter_map(|&index| {
|
||||||
|
let applied = mutation::apply(
|
||||||
|
langfn.clone(),
|
||||||
|
source_bytes,
|
||||||
|
root_node,
|
||||||
|
&self.mutations_collection[index],
|
||||||
|
);
|
||||||
|
match applied {
|
||||||
|
Ok(v) => Some(v),
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!(
|
||||||
|
collection_index = index,
|
||||||
|
"failed to apply mutations from collection {}",
|
||||||
|
e
|
||||||
|
);
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
Ok(collected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn dump_expression(path: &Path) -> Result<String, Error> {
|
||||||
|
let Some(lang) = path.extension() else {
|
||||||
|
return Err(Error::UnknownLang);
|
||||||
|
};
|
||||||
|
let lang = lang.to_str().ok_or(Error::UnknownLang)?;
|
||||||
|
let langfn = Refactor::get_lang(lang)?;
|
||||||
|
let mut parser = Parser::new();
|
||||||
|
parser
|
||||||
|
.set_language(&langfn)
|
||||||
|
.map_err(|_| Error::UnknownLang)?;
|
||||||
|
|
||||||
|
let source_code = std::fs::read_to_string(path).map_err(|_| Error::SnippetParsing)?;
|
||||||
|
let source_bytes = source_code.as_bytes();
|
||||||
|
let tree = parser
|
||||||
|
.parse(source_bytes, None)
|
||||||
|
.ok_or(Error::SnippetParsing)?;
|
||||||
|
let root_node = tree.root_node();
|
||||||
|
Ok(root_node.to_sexp().to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Generate {
|
||||||
|
pub dict: HashMap<String, HNSWIndex<f32, String>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Generate {
|
||||||
|
fn search(&self, lang: &str, target: &[f32], top_k: usize) -> Result<Vec<String>, Error> {
|
||||||
|
let Some(snippets_for_lang) = self.dict.get(lang) else {
|
||||||
|
return Err(Error::UnknownLang);
|
||||||
|
};
|
||||||
|
Ok(snippets_for_lang.search(target, top_k))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct State {
|
pub struct State {
|
||||||
pub embed: crate::embed::Embed,
|
embed: crate::embed::Embed,
|
||||||
pub v1: crate::v1::api::State,
|
generate: Generate,
|
||||||
pub v2: crate::v2::api::State,
|
refactor: Refactor,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl State {
|
impl State {
|
||||||
pub fn build(self) -> StateWrapper {
|
pub fn new(embed: crate::embed::Embed, generate: Generate, refactor: Refactor) -> Self {
|
||||||
StateWrapper {
|
Self {
|
||||||
inner: Mutex::new(self),
|
embed,
|
||||||
|
generate,
|
||||||
|
refactor,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
pub fn generate(&self, lang: &str, prompt: &str, top_k: usize) -> Result<Vec<String>, Error> {
|
||||||
|
let Ok(target) = self.embed.embed(prompt) else {
|
||||||
|
return Err(Error::EmbedFailed);
|
||||||
|
};
|
||||||
|
|
||||||
|
self.generate.search(lang, &target, top_k)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn refactor(
|
||||||
|
&self,
|
||||||
|
lang: &str,
|
||||||
|
prompt: &str,
|
||||||
|
body: &str,
|
||||||
|
top_k: usize,
|
||||||
|
) -> Result<Vec<String>, Error> {
|
||||||
|
let Ok(target) = self.embed.embed(prompt) else {
|
||||||
|
return Err(Error::EmbedFailed);
|
||||||
|
};
|
||||||
|
|
||||||
|
self.refactor.search(lang, &target, body, top_k)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,86 +0,0 @@
|
|||||||
use hora::core::ann_index::ANNIndex;
|
|
||||||
use std::collections::HashMap;
|
|
||||||
|
|
||||||
use super::errors::Error;
|
|
||||||
use actix_web::{Responder, post, web};
|
|
||||||
use anyhow::Result;
|
|
||||||
use hora::index::hnsw_idx::HNSWIndex;
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
pub struct SnippetRequest {
|
|
||||||
desc: String,
|
|
||||||
top_k: Option<usize>,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct State {
|
|
||||||
pub dict: HashMap<String, HNSWIndex<f32, String>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize)]
|
|
||||||
pub struct SnippetResponse {
|
|
||||||
id: usize,
|
|
||||||
snippet: Snippet,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize)]
|
|
||||||
pub struct Snippet {
|
|
||||||
lang: String,
|
|
||||||
desc: String,
|
|
||||||
body: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[post("/api/v1/get")]
|
|
||||||
pub(crate) async fn get_snippet(
|
|
||||||
data: web::Data<crate::state::StateWrapper>,
|
|
||||||
snippet_request: web::Json<SnippetRequest>,
|
|
||||||
) -> Result<impl Responder, Error> {
|
|
||||||
let Some((prompt, lang)) = snippet_request.desc.rsplit_once(" in ") else {
|
|
||||||
return Err(Error::MissingSuffix);
|
|
||||||
};
|
|
||||||
|
|
||||||
let Ok(mut appstate) = data.inner.lock() else {
|
|
||||||
return Err(Error::Busy);
|
|
||||||
};
|
|
||||||
|
|
||||||
let Ok(target) = appstate.embed.embed(prompt) else {
|
|
||||||
return Err(Error::EmbedFailed);
|
|
||||||
};
|
|
||||||
|
|
||||||
let Some(snippets_for_lang) = appstate.v1.dict.get(lang) else {
|
|
||||||
return Err(Error::UnknownLang);
|
|
||||||
};
|
|
||||||
// search for k nearest neighbors
|
|
||||||
let closest = snippets_for_lang.search(&target, snippet_request.top_k.unwrap_or(1));
|
|
||||||
Ok(web::Json(closest))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[post("/api/v1/add")]
|
|
||||||
pub(crate) async fn add_snippet(
|
|
||||||
data: web::Data<crate::state::StateWrapper>,
|
|
||||||
snippet: web::Json<Snippet>,
|
|
||||||
) -> Result<impl Responder, Error> {
|
|
||||||
let Ok(mut appstate) = data.inner.lock() else {
|
|
||||||
return Err(Error::Busy);
|
|
||||||
};
|
|
||||||
let Ok(embedding) = appstate.embed.embed(&snippet.desc) else {
|
|
||||||
return Err(Error::EmbedFailed);
|
|
||||||
};
|
|
||||||
let index = appstate
|
|
||||||
.v1
|
|
||||||
.dict
|
|
||||||
.entry(snippet.lang.clone())
|
|
||||||
.or_insert_with(|| {
|
|
||||||
let dimension = 384;
|
|
||||||
let params = hora::index::hnsw_params::HNSWParams::<f32>::default();
|
|
||||||
|
|
||||||
HNSWIndex::<f32, String>::new(dimension, &params)
|
|
||||||
});
|
|
||||||
index.add(&embedding, snippet.body.clone()).unwrap();
|
|
||||||
index.build(hora::core::metrics::Metric::Euclidean).unwrap();
|
|
||||||
|
|
||||||
Ok(format!(
|
|
||||||
"{} {} {}",
|
|
||||||
snippet.body, snippet.lang, snippet.desc
|
|
||||||
))
|
|
||||||
}
|
|
||||||
@@ -1,38 +0,0 @@
|
|||||||
use actix_web::{
|
|
||||||
HttpResponse, error,
|
|
||||||
http::{StatusCode, header::ContentType},
|
|
||||||
};
|
|
||||||
use derive_more::derive::{Display, Error};
|
|
||||||
use serde_json::json;
|
|
||||||
|
|
||||||
#[derive(Debug, Display, Error)]
|
|
||||||
pub enum Error {
|
|
||||||
#[display("the server is busy. come back later.")]
|
|
||||||
Busy,
|
|
||||||
#[display("end your request with ` in somelang`.")]
|
|
||||||
MissingSuffix,
|
|
||||||
#[display("failed to embed your prompt.")]
|
|
||||||
EmbedFailed,
|
|
||||||
#[display("snippets were requested for an unknown language")]
|
|
||||||
UnknownLang,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl error::ResponseError for Error {
|
|
||||||
fn error_response(&self) -> HttpResponse {
|
|
||||||
let message = json!({
|
|
||||||
"message": self.to_string(),
|
|
||||||
})
|
|
||||||
.to_string();
|
|
||||||
HttpResponse::build(self.status_code())
|
|
||||||
.insert_header(ContentType::json())
|
|
||||||
.body(message)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn status_code(&self) -> StatusCode {
|
|
||||||
match *self {
|
|
||||||
Self::EmbedFailed => StatusCode::INTERNAL_SERVER_ERROR,
|
|
||||||
Self::MissingSuffix | Self::UnknownLang => StatusCode::BAD_REQUEST,
|
|
||||||
Self::Busy => StatusCode::GATEWAY_TIMEOUT,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,2 +0,0 @@
|
|||||||
pub(crate) mod api;
|
|
||||||
pub(crate) mod errors;
|
|
||||||
118
src/v2/api.rs
118
src/v2/api.rs
@@ -1,118 +0,0 @@
|
|||||||
use hora::{core::ann_index::ANNIndex, index::hnsw_idx::HNSWIndex};
|
|
||||||
use std::collections::HashMap;
|
|
||||||
use tracing::{error, info};
|
|
||||||
use tree_sitter::Parser;
|
|
||||||
|
|
||||||
use super::{errors::Error, mutation};
|
|
||||||
use actix_web::{Responder, post, web};
|
|
||||||
use anyhow::Result;
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
|
|
||||||
pub struct State {
|
|
||||||
pub dict: HashMap<String, HNSWIndex<f32, usize>>,
|
|
||||||
pub mutations_collection: Vec<mutation::MutationCollection>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
pub struct SnippetRequest {
|
|
||||||
desc: String,
|
|
||||||
body: String,
|
|
||||||
top_k: Option<usize>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize)]
|
|
||||||
pub struct SnippetResponse {
|
|
||||||
id: usize,
|
|
||||||
snippet: Snippet,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize)]
|
|
||||||
pub struct Snippet {
|
|
||||||
lang: String,
|
|
||||||
desc: String,
|
|
||||||
body: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get_lang(s: &str) -> Result<tree_sitter::Language, Error> {
|
|
||||||
Ok(match s {
|
|
||||||
"go" => tree_sitter_go::LANGUAGE,
|
|
||||||
"c" => tree_sitter_c::LANGUAGE,
|
|
||||||
"rust" => tree_sitter_rust::LANGUAGE,
|
|
||||||
_ => return Err(Error::UnknownLang),
|
|
||||||
}
|
|
||||||
.into())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[post("/api/v2/get")]
|
|
||||||
pub(crate) async fn get_snippet(
|
|
||||||
data: web::Data<crate::state::StateWrapper>,
|
|
||||||
snippet_request: web::Json<SnippetRequest>,
|
|
||||||
) -> Result<impl Responder, Error> {
|
|
||||||
let Some((prompt, lang)) = snippet_request.desc.rsplit_once(" in ") else {
|
|
||||||
return Err(Error::MissingSuffix);
|
|
||||||
};
|
|
||||||
|
|
||||||
let closest = closest_mutation(
|
|
||||||
lang,
|
|
||||||
prompt,
|
|
||||||
snippet_request.body.as_str(),
|
|
||||||
snippet_request.top_k.unwrap_or(1),
|
|
||||||
&data,
|
|
||||||
)?;
|
|
||||||
Ok(web::Json(closest))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn closest_mutation(
|
|
||||||
lang: &str,
|
|
||||||
prompt: &str,
|
|
||||||
body: &str,
|
|
||||||
top_k: usize,
|
|
||||||
data: &web::Data<crate::state::StateWrapper>,
|
|
||||||
) -> Result<Vec<String>, Error> {
|
|
||||||
let langfn = get_lang(lang)?;
|
|
||||||
|
|
||||||
info!(prompt = prompt, language = lang, "v2 request");
|
|
||||||
|
|
||||||
let mut appstate = data.inner.lock().map_err(|_| Error::Busy)?;
|
|
||||||
let target = appstate
|
|
||||||
.embed
|
|
||||||
.embed(prompt)
|
|
||||||
.map_err(|_| Error::EmbedFailed)?;
|
|
||||||
let mut parser = Parser::new();
|
|
||||||
parser
|
|
||||||
.set_language(&langfn)
|
|
||||||
.map_err(|_| Error::UnknownLang)?;
|
|
||||||
|
|
||||||
let source_code = body;
|
|
||||||
let source_bytes = source_code.as_bytes();
|
|
||||||
let tree = parser
|
|
||||||
.parse(source_code, None)
|
|
||||||
.ok_or(Error::SnippetParsing)?;
|
|
||||||
let root_node = tree.root_node();
|
|
||||||
|
|
||||||
// search for k nearest neighbors
|
|
||||||
let collected = appstate.v2.dict[lang]
|
|
||||||
.search(&target, top_k)
|
|
||||||
.iter()
|
|
||||||
.filter_map(|&index| {
|
|
||||||
let applied = mutation::apply(
|
|
||||||
langfn.clone(),
|
|
||||||
source_bytes,
|
|
||||||
root_node,
|
|
||||||
&appstate.v2.mutations_collection[index],
|
|
||||||
);
|
|
||||||
match applied {
|
|
||||||
Ok(v) => Some(v),
|
|
||||||
Err(e) => {
|
|
||||||
error!(
|
|
||||||
collection_index = index,
|
|
||||||
"failed to apply mutations from collection {}", e
|
|
||||||
);
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// TODO: change the expect to a log
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
Ok(collected)
|
|
||||||
}
|
|
||||||
@@ -1,42 +0,0 @@
|
|||||||
use actix_web::{
|
|
||||||
HttpResponse, error,
|
|
||||||
http::{StatusCode, header::ContentType},
|
|
||||||
};
|
|
||||||
use derive_more::derive::{Display, Error};
|
|
||||||
use serde_json::json;
|
|
||||||
|
|
||||||
#[derive(Debug, Display, Error)]
|
|
||||||
pub enum Error {
|
|
||||||
#[display("the server is busy. come back later.")]
|
|
||||||
Busy,
|
|
||||||
#[display("end your request with ` in somelang`.")]
|
|
||||||
MissingSuffix,
|
|
||||||
#[display("failed to embed your prompt.")]
|
|
||||||
EmbedFailed,
|
|
||||||
#[display("snippets were requested for an unknown language")]
|
|
||||||
UnknownLang,
|
|
||||||
#[display("failed to parse corpus of code to apply mutation on")]
|
|
||||||
SnippetParsing,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl error::ResponseError for Error {
|
|
||||||
fn error_response(&self) -> HttpResponse {
|
|
||||||
let message = json!({
|
|
||||||
"message": self.to_string(),
|
|
||||||
})
|
|
||||||
.to_string();
|
|
||||||
HttpResponse::build(self.status_code())
|
|
||||||
.insert_header(ContentType::json())
|
|
||||||
.body(message)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn status_code(&self) -> StatusCode {
|
|
||||||
match *self {
|
|
||||||
Self::EmbedFailed => StatusCode::INTERNAL_SERVER_ERROR,
|
|
||||||
Self::MissingSuffix | Self::UnknownLang | Self::SnippetParsing => {
|
|
||||||
StatusCode::BAD_REQUEST
|
|
||||||
}
|
|
||||||
Self::Busy => StatusCode::GATEWAY_TIMEOUT,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
pub(crate) mod api;
|
|
||||||
pub(crate) mod errors;
|
|
||||||
pub(crate) mod mutation;
|
|
||||||
Reference in New Issue
Block a user