78 Commits

Author SHA1 Message Date
Himadri Bhattacharjee
8e045994a8 feat: devShell: add bacon to flake 2025-11-13 19:04:16 +05:30
Himadri Bhattacharjee
48a15defcf feat: add dry-run subcommand to check snippet on sample code 2025-11-13 19:03:59 +05:30
Himadri Bhattacharjee
4fb97f27b0 deps: bump cargo dependencies 2025-11-13 19:03:39 +05:30
Himadri Bhattacharjee
bb6e5fca7e Merge pull request #13 from lavafroth/dependabot/cargo/tracing-subscriber-0.3.20
chore(deps): bump tracing-subscriber from 0.3.19 to 0.3.20
2025-11-13 17:35:05 +05:30
Himadri Bhattacharjee
fd5fb501df Merge pull request #12 from lavafroth/dependabot/cargo/tree-sitter-javascript-0.25.0
chore(deps): bump tree-sitter-javascript from 0.23.1 to 0.25.0
2025-11-13 17:33:55 +05:30
dependabot[bot]
dbeea93010 chore(deps): bump tracing-subscriber from 0.3.19 to 0.3.20
Bumps [tracing-subscriber](https://github.com/tokio-rs/tracing) from 0.3.19 to 0.3.20.
- [Release notes](https://github.com/tokio-rs/tracing/releases)
- [Commits](https://github.com/tokio-rs/tracing/compare/tracing-subscriber-0.3.19...tracing-subscriber-0.3.20)

---
updated-dependencies:
- dependency-name: tracing-subscriber
  dependency-version: 0.3.20
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-09-08 20:56:25 +00:00
dependabot[bot]
af831852d9 chore(deps): bump tree-sitter-javascript from 0.23.1 to 0.25.0
Bumps [tree-sitter-javascript](https://github.com/tree-sitter/tree-sitter-javascript) from 0.23.1 to 0.25.0.
- [Release notes](https://github.com/tree-sitter/tree-sitter-javascript/releases)
- [Commits](https://github.com/tree-sitter/tree-sitter-javascript/compare/v0.23.1...v0.25.0)

---
updated-dependencies:
- dependency-name: tree-sitter-javascript
  dependency-version: 0.25.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-09-08 20:54:52 +00:00
Himadri Bhattacharjee
6eab02575a Merge pull request #10 from lavafroth/dependabot/cargo/clap-4.5.45
chore(deps): bump clap from 4.5.41 to 4.5.45
2025-09-01 05:29:39 +00:00
Himadri Bhattacharjee
80fe8a3f16 Merge pull request #8 from lavafroth/dependabot/cargo/tokenizers-0.21.4
chore(deps): bump tokenizers from 0.21.1 to 0.21.4
2025-09-01 05:24:10 +00:00
Himadri Bhattacharjee
60256c06cc feat: update vscode lsp config 2025-08-26 12:02:58 +05:30
Himadri Bhattacharjee
caf4f51d22 Merge branch 'dump-expression' 2025-08-26 11:59:13 +05:30
Himadri Bhattacharjee
a543e80a04 ver: bump version for next release 2025-08-26 11:58:53 +05:30
Himadri Bhattacharjee
f0c137ade4 fix: reintroduce the root node for anchoring flexibility 2025-08-26 11:57:15 +05:30
Himadri Bhattacharjee
7b0f818d38 feat: parse capture groups with + or * wildcards 2025-08-26 10:41:54 +05:30
dependabot[bot]
4195cbb734 chore(deps): bump clap from 4.5.41 to 4.5.45
Bumps [clap](https://github.com/clap-rs/clap) from 4.5.41 to 4.5.45.
- [Release notes](https://github.com/clap-rs/clap/releases)
- [Changelog](https://github.com/clap-rs/clap/blob/master/CHANGELOG.md)
- [Commits](https://github.com/clap-rs/clap/compare/clap_complete-v4.5.41...clap_complete-v4.5.45)

---
updated-dependencies:
- dependency-name: clap
  dependency-version: 4.5.45
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-08-19 05:05:31 +00:00
Himadri Bhattacharjee
e5602c688c add subcommands for ast quick actions 2025-08-18 16:57:03 +05:30
Himadri Bhattacharjee
89e4c3b5fb docs: document binary releases 2025-08-18 10:29:54 +05:30
Himadri Bhattacharjee
e24e62873f ci: remove macos release currently resulting unusable builds 2025-08-18 09:36:43 +05:30
Himadri Bhattacharjee
0b9ab89f35 ci: release build script
ci: use stable toolchain
2025-08-17 18:13:36 +05:30
Himadri Bhattacharjee
650329206d docs: document new supported languages 2025-08-17 17:37:34 +05:30
Himadri Bhattacharjee
7d9c3a448f feat: cli flag to dump S expression for a source file
TODO: move to being a subcommand
2025-08-17 17:34:43 +05:30
Himadri Bhattacharjee
633c1a206b deps: bump flake 2025-08-13 19:31:42 +05:30
Himadri Bhattacharjee
ab4c62fcf4 lint: clippy 2025-08-13 19:31:33 +05:30
Himadri Bhattacharjee
6ff9ba9d16 feat: add javascript and cpp language support 2025-08-13 19:31:26 +05:30
dependabot[bot]
e348b9a830 chore(deps): bump tokenizers from 0.21.1 to 0.21.4
Bumps [tokenizers](https://github.com/huggingface/tokenizers) from 0.21.1 to 0.21.4.
- [Release notes](https://github.com/huggingface/tokenizers/releases)
- [Changelog](https://github.com/huggingface/tokenizers/blob/main/RELEASE.md)
- [Commits](https://github.com/huggingface/tokenizers/compare/v0.21.1...v0.21.4)

---
updated-dependencies:
- dependency-name: tokenizers
  dependency-version: 0.21.4
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-08-12 06:07:01 +00:00
Himadri Bhattacharjee
d359121afd feat: add support for multiple files in the workdir 2025-08-01 20:29:58 +05:30
Himadri Bhattacharjee
4abd2cffac Merge pull request #5 from lavafroth/dependabot/cargo/clap-4.5.41
chore(deps): bump clap from 4.5.39 to 4.5.41
2025-07-29 07:04:12 +00:00
Himadri Bhattacharjee
daccd63006 feat: remove redundant normalization by token count before l2_norm of embeddings 2025-07-22 19:38:39 +05:30
dependabot[bot]
87e096f0bc Merge pull request #3 from lavafroth/dependabot/cargo/tree-sitter-0.25.8 2025-07-19 14:59:25 +00:00
Himadri Bhattacharjee
91d2640c11 feat: add preview gif 2025-07-19 19:23:58 +05:30
dependabot[bot]
ec3b89f455 chore(deps): bump clap from 4.5.39 to 4.5.41
Bumps [clap](https://github.com/clap-rs/clap) from 4.5.39 to 4.5.41.
- [Release notes](https://github.com/clap-rs/clap/releases)
- [Changelog](https://github.com/clap-rs/clap/blob/master/CHANGELOG.md)
- [Commits](https://github.com/clap-rs/clap/compare/clap_complete-v4.5.39...clap_complete-v4.5.41)

---
updated-dependencies:
- dependency-name: clap
  dependency-version: 4.5.41
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-07-19 08:07:55 +00:00
Himadri Bhattacharjee
c734c81a04 feat: implement shallow globbing; removed dep glob 2025-07-19 13:36:39 +05:30
dependabot[bot]
e7cae348a1 chore(deps): bump tree-sitter from 0.25.6 to 0.25.8
Bumps [tree-sitter](https://github.com/tree-sitter/tree-sitter) from 0.25.6 to 0.25.8.
- [Release notes](https://github.com/tree-sitter/tree-sitter/releases)
- [Commits](https://github.com/tree-sitter/tree-sitter/compare/v0.25.6...v0.25.8)

---
updated-dependencies:
- dependency-name: tree-sitter
  dependency-version: 0.25.8
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-07-17 07:43:28 +00:00
Himadri Bhattacharjee
faea784d8f ci: add dependabot 2025-07-17 13:12:23 +05:30
Himadri Bhattacharjee
e8970f21ff ci: add pipeline to build and test crate 2025-07-17 07:34:50 +00:00
Himadri Bhattacharjee
a1445b2f03 fmt: cargo fmt 2025-07-14 17:36:43 +05:30
Himadri Bhattacharjee
8f5e618841 fix: reshape tokenized f32 vector dimensions to hidden_size 2025-07-14 17:36:43 +05:30
Himadri Bhattacharjee
996142c8dd feat: custom snippet dir support 2025-07-14 17:18:24 +05:30
Himadri Bhattacharjee
b32ed48471 feat: dotdirs for editors 2025-07-14 06:49:05 +05:30
Himadri Bhattacharjee
a4b77cd40b docs: update installation instructions 2025-07-14 06:35:57 +05:30
Himadri Bhattacharjee
c137a6c586 ver: bump version: breaking changes removed REST API 2025-07-14 06:28:38 +05:30
Himadri Bhattacharjee
7e6300f8e9 feat: add vscode-lspconfig settings 2025-07-14 06:27:20 +05:30
Himadri Bhattacharjee
949e5f5df4 docs: remove http api cruft 2025-07-14 06:27:01 +05:30
Himadri Bhattacharjee
08df630a2b feat: remove all actix components 2025-07-09 13:10:51 +05:30
Himadri Bhattacharjee
c8a62b9b42 docs: update docs for generate and refactor directives 2025-07-09 11:42:07 +05:30
Himadri Bhattacharjee
34aeae0496 feat: deprecate language inference from prompt 2025-07-09 08:21:33 +05:30
Himadri Bhattacharjee
bc4f29393e feat: remove language dependence on prompt
breaking change
2025-07-09 08:21:33 +05:30
Himadri Bhattacharjee
dfa414163d chore: extract dimensions into larger scope 2025-07-09 07:55:28 +05:30
Himadri Bhattacharjee
a6b9910d39 feat: add generate feature for lsp 2025-07-08 17:27:15 +05:30
Himadri Bhattacharjee
cf11c5774d fix: parse only the selected text in contextual refactor 2025-07-08 17:05:08 +05:30
Himadri Bhattacharjee
291e2e1832 docs: use smaller sections 2025-07-02 10:57:33 +05:30
Himadri Bhattacharjee
bd772eec48 docs: document lsp mode 2025-07-02 10:49:24 +05:30
Himadri Bhattacharjee
56dc5f4393 ver: 2.0.0 merge branch 'lsp' 2025-07-02 10:20:31 +05:30
Himadri Bhattacharjee
864c394ed7 ver: 2.0.0 lsp mode introduces breaking changes 2025-07-02 10:19:38 +05:30
Himadri Bhattacharjee
55e915cc32 feat: shard lsp module 2025-07-02 10:15:52 +05:30
Himadri Bhattacharjee
5dadaca767 feat: v1 api: better error handling 2025-07-02 09:53:57 +05:30
Himadri Bhattacharjee
efec8c2220 refactor: decouple cli args and embed module 2025-07-01 19:07:07 +05:30
Himadri Bhattacharjee
716b9ed3e2 refactor: move common closest mutation search into a function 2025-07-01 18:58:18 +05:30
Himadri Bhattacharjee
4b710e8675 feat: add mode handler for http or lsp 2025-06-30 07:35:32 +05:30
Himadri Bhattacharjee
645a987cf1 feat: scaffolding for lsp
very sharp edges
2025-06-29 19:58:30 +05:30
Himadri Bhattacharjee
cae538f27b chore: lockfile 2025-06-27 21:59:47 +05:30
Himadri Bhattacharjee
9bc7614ece docs: mention supported languages 2025-06-27 21:59:47 +05:30
Himadri Bhattacharjee
d66be83d2b feat: do not require root definition 2025-06-27 21:59:47 +05:30
Himadri Bhattacharjee
caaa89fecf feat: add an example for v2 api 2025-06-27 21:59:47 +05:30
Himadri Bhattacharjee
19d82da5e2 docs: document v2 api 2025-06-27 21:59:47 +05:30
Himadri Bhattacharjee
3321bd4552 feat: better error handling and logging 2025-06-27 21:53:33 +05:30
Himadri Bhattacharjee
ca8d8e6a59 fix: single purpose mutation 2025-06-27 21:53:33 +05:30
Himadri Bhattacharjee
10048ffb04 feat: move to tracing and tracing-subscriber for logs 2025-06-27 21:53:32 +05:30
Himadri Bhattacharjee
7174ed5e7e fix: find the directory basename for path_to_parent_base 2025-06-27 21:53:21 +05:30
Himadri Bhattacharjee
8470255c5f Merge pull request #1 from lavafroth/v2-exp
V2 experimental tree-sitter AST manipulation API
2025-06-27 21:53:21 +05:30
Himadri Bhattacharjee
e91333a2ca Merge branch 'master' into v2-exp 2025-06-21 11:48:12 +05:30
Himadri Bhattacharjee
816f758e67 chore: move print macros to span for logging later 2025-06-21 11:45:03 +05:30
Himadri Bhattacharjee
f07a688708 ver: update the package version in cargo.toml 2025-06-21 11:45:03 +05:30
Himadri Bhattacharjee
11356413ab chore: remove needless regex crate 2025-06-21 11:37:17 +05:30
Himadri Bhattacharjee
14318cf2fb feat: better error handling for api v2 2025-06-21 11:13:36 +05:30
Himadri Bhattacharjee
7e03d5a346 feat: embed v1 and v2 states in a bigger state struct
pass them individually to API functions when needed
2025-06-20 18:36:50 +05:30
Himadri Bhattacharjee
90e3983f34 feat: v2 parses all kdl definitions via globs 2025-06-20 08:10:33 +05:30
Himadri Bhattacharjee
7ca19deff6 feat: scaffolding for v2 mutation api 2025-06-19 19:27:36 +05:30
33 changed files with 1866 additions and 1310 deletions

8
.github/dependabot.yml vendored Normal file
View File

@@ -0,0 +1,8 @@
version: 2
updates:
- package-ecosystem: "cargo" # See documentation for possible values
directory: "/" # Location of package manifests
schedule:
interval: "weekly"
open-pull-requests-limit: 4

26
.github/workflows/release.yml vendored Normal file
View File

@@ -0,0 +1,26 @@
on:
release:
types: [created]
jobs:
release:
name: release ${{ matrix.target }}
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
include:
- target: x86_64-pc-windows-gnu
archive: zip
- target: x86_64-unknown-linux-musl
archive: tar.zst
steps:
- uses: actions/checkout@master
- name: Compile and release
uses: rust-build/rust-build.action@v1.4.5
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
RUSTTARGET: ${{ matrix.target }}
ARCHIVE_TYPES: ${{ matrix.archive }}
TOOLCHAIN_VERSION: stable

22
.github/workflows/rust.yml vendored Normal file
View File

@@ -0,0 +1,22 @@
name: Build and test
on:
push:
branches: [ "master" ]
pull_request:
branches: [ "master" ]
env:
CARGO_TERM_COLOR: always
jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Build
run: cargo build --verbose
- name: Run tests
run: cargo test --verbose

11
.helix/languages.toml Normal file
View File

@@ -0,0 +1,11 @@
[language-server.silos]
command = "./target/debug/silos"
args = ["lsp"]
[[language]]
name = "go"
language-servers = [ { name = "silos" }, "gopls" ]
[[language]]
name = "rust"
language-servers = [ ]

16
.vscode/settings.json vendored Normal file
View File

@@ -0,0 +1,16 @@
{
"vscode-lspconfig.serverConfigurations": [
{
"name": "silos",
"document_selector": [
{
"language": "go"
}
],
"command": [
"silos"
"lsp"
]
}
]
}

1833
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -1,20 +1,27 @@
[package]
name = "silos"
version = "1.0.0"
version = "6.0.0"
edition = "2024"
[dependencies]
actix-web = "4.11.0"
anyhow = "1.0.98"
candle-core = "0.9.1"
candle-nn = "0.9.1"
candle-transformers = "0.9.1"
clap = { version = "4.5.39", features = ["derive"] }
derive_more = "2.0.1"
glob = "0.3.2"
clap = { version = "4.5.45", features = ["derive"] }
derive_more = { version = "2.0.1", features = ["display", "error"] }
hf-hub = "0.4.2"
hora = "0.1.1"
kdl = "6.3.4"
serde = "1.0.219"
serde_json = "1.0.140"
tokenizers = "0.21.1"
tokenizers = "0.21.4"
tracing = "0.1.41"
tracing-subscriber = "0.3.20"
tree-sitter = "0.25.8"
tree-sitter-c = "0.24.1"
tree-sitter-go = "0.23.4"
tree-sitter-rust = "0.24.0"
tokio = { version = "1.45.1", features = ["io-std", "macros", "rt", "rt-multi-thread"] }
tower-lsp = "0.20.0"
tree-sitter-javascript = "0.25.0"
tree-sitter-cpp = "0.23.4"

130
README.md
View File

@@ -2,70 +2,132 @@
Dumb, proomptable modular snippet search.
## Getting started
![preview](./assets/preview.gif)
### Installation
## Installation
You can download a binary from releases tab or build the project from source.
### From source
Prerequisites:
- libc
- [rust toolchain](https://rustup.rs)
Clone this repository and enter it
Clone this repository and build it.
``` sh
git clone https://github.com/lavafroth/silos
cd silos
cargo install --git https://github.com/lavafroth/silos
```
### Setup
## Editor support
Add your code snippets as KDL files in the `./snippets/v1/LANGUAGE/` directory, Take a look at the example snippet for golang in `./snippets/v1/go/simple_worker.kdl`.
- Helix: Use the example `.helix` directory provided to run the LSP for files under `./examples/`.
- Neovim: Please follow [the official guide](https://neovim.io/doc/user/lsp.html).
- VSCode: Use the [vscode-lspconfig](https://marketplace.visualstudio.com/items?itemName=whtsht.vscode-lspconfig) extension with the `.vscode/settings.json` provided.
The snippets must conform to the following structure:
Make sure to modify the binary path in the example to where you have it on your system.
``` kdl
desc "a well articulated description of the snippet",
body #"fn main() { println!("The body of the snippet") }"#
## Usage
- Write a comment above a paragraph of code, consider the example in examples/example.go
``` go
resumeFilename := "resume.pdf"
version := 3
// refactor: change the file basename to that of the parent
whereIsMyResume :=
filepath.Base(
documentsDirectory + "CV" + "_v" + strconv.Itoa(version) + "/" + resumeFilename)
```
KDL supports arbitrary raw strings with as many `#`s before and after the quotes to disambiguate them from the string contents.
After adding your snippets, run the server
``` sh
cargo r
```
- The comment must begin with either of
- `generate: `
- `refactor: `
- Select the code to be modified along with the comment above it.
- Trigger code actions. In helix, this is `space`, `a`.
- Select the option called "ask silos."
> [!NOTE]
>
> Embedding defaults to using the CPU. You may use the `--gpu` flag with a GPU number to use a dedicated GPU.
### Usage
## `generate` snippets
An HTTP REST API listens on port 8000 and can be queried for code snippets.
- Stored in the KDL format inside per-language directories under `./snippets/v1`.
- They must conform to the following structure
#### Query a snippet
``` sh
curl http://localhost:8000/api/v1/get --json '{ "desc": "channeled worker in go" }'
``` kdl
desc "describes the snippet"
body #"the snippet itself"#
```
You must add the "in someLanguage" suffix to your query's description field. This is to keep the API design simple for bothIDE and non-IDE users.
KDL supports arbitrary raw strings with as many `#`s before and after the quotes to disambiguate them from the string contents.
#### Add a snippet
See the example snippet `./snippets/v1/go/simple_worker.kdl` in the go programming language.
``` sh
curl http://localhost:8000/api/v1/add --json \
'{ "desc": "Build an asynchronous shared mutable state", "lang": "rust", "body": "let object = Arc::new(Mutex::new(old));" }'
## `refactor` snippets
This API parses code into an AST (Abstract Syntax Tree) via tree-sitter and can perform subsequent mutations.
### Supported Languages
- C
- Rust
- Go
- Javascript
- C++
### Defining mutation collections
``` kdl
description "describes the mutation collection"
mutation {
expression "(some ((beautiful) @adjective) AST expression) @root"
substitute {
literal "hello"
capture "adjective"
literal "world"
}
}
mutation {
expression "(another) @root"
substitute {
literal "multiple mutations work"
literal "as long as their expression"
literal "don't collide"
}
}
```
## v2 API
- `description`: A textual description of the mutation collection.
- `mutation`: Defines individual code changes.
- `expression`: Uses tree-sitter to match and capture AST nodes with `@` prefixes,
- The special `@root` node must be specify the expression to be replaced.
- `substitute`: Constructs the modified code using literals and captured arguments.
Language grammar parsing with abstract syntax tree manipulation support.
See the example mutation collection in `./snippets/v2/go/filepath-parent.kdl`.
Coming soon
- The API performs a single-pass substitution based on the closest matching mutation.
- Captured groups are used within the `substitute` block and the mutated code is returned.
## TODOs
> Every capture group must contain the largest atom to be operated on.
For example: if you wish to operate on elements of an array, capture each identifier inside the array
- [ ] Create an LSP to add the suffix based on filetype.
Correct way: Here the `array` and `identifier` only hints about where the expression `root` lies.
```
(array (identifier @root))
```
Incorrect way: Here the root expression matches the block all the array elements inside the braces, not each element.
```
(array ((identifier)*) @entire-block-capture) @root
```
**Further reading**
- [tree-sitter query snytax](https://tree-sitter.github.io/tree-sitter/using-parsers/queries/1-syntax.html) to create mutation expressions.

BIN
assets/preview.gif Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.9 MiB

17
examples/example.go Normal file
View File

@@ -0,0 +1,17 @@
package main
import (
"fmt"
"path/filepath"
"strconv"
)
func main() {
documentsDirectory := "/home/h/Documents/"
resumeFilename := "resume.pdf"
version := 3
whereIsMyResume :=
filepath.Base(
documentsDirectory + "CV" + "_v" + strconv.Itoa(version) + "/" + resumeFilename)
fmt.Println(whereIsMyResume)
}

6
flake.lock generated
View File

@@ -2,11 +2,11 @@
"nodes": {
"nixpkgs": {
"locked": {
"lastModified": 1749776303,
"narHash": "sha256-OHibOvVwKqO1qvRg0r3agtd1EagW4THBcoWT7QGgcNo=",
"lastModified": 1755020227,
"narHash": "sha256-gGmm+h0t6rY88RPTaIm3su95QvQIVjAJx558YUG4Id8=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "6e7721e37bf00fa7ea44ac3cfc9d2411284ec3ef",
"rev": "695d5db1b8b20b73292501683a524e0bd79074fb",
"type": "github"
},
"original": {

View File

@@ -13,6 +13,7 @@
packages = with pkgs; [
stdenv.cc.cc
pkg-config
bacon
];
libraries = with pkgs; [

View File

@@ -1,3 +1,3 @@
desc "display all the path entries"
body """printf "%s\n" $PATH"""
body #"printf "%s\n" $PATH"#

View File

@@ -0,0 +1,13 @@
description "base64 import"
mutation {
expression "(import_spec_list ((import_spec)* @spec)) @root"
substitute {
literal "("
literal "\n"
capture "spec"
literal "\n"
literal #""base64""#
literal "\n"
literal ")"
}
}

View File

@@ -0,0 +1,14 @@
description "filepath base to parent's base"
mutation {
expression """
(call_expression
function: (_) @func (#eq? @func "filepath.Base")
arguments: (_) @args
) @root
"""
substitute {
literal "filepath.Base(filepath.Dir(filepath.Clean"
capture "args"
literal "))"
}
}

79
src/args.rs Normal file
View File

@@ -0,0 +1,79 @@
use clap::{Args, Parser, Subcommand};
use std::path::PathBuf;
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
pub(crate) struct Cli {
#[command(subcommand)]
pub command: Command,
}
#[derive(Args, Debug)]
pub(crate) struct Lsp {
/// Run on the Nth GPU device.
#[arg(long)]
pub(crate) gpu: Option<usize>,
/// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
#[arg(long)]
pub(crate) model_id: Option<String>,
/// Revision or branch.
#[arg(long)]
pub(crate) revision: Option<String>,
/// Path to the directory containing `generate` and `refactor` snippets.
#[arg(long, default_value = "./snippets")]
pub(crate) snippets: std::path::PathBuf,
}
#[derive(Args, Debug)]
pub struct DumpExpression {
pub path: PathBuf,
}
#[derive(Args, Debug)]
pub struct ShowCaptures {
pub path: PathBuf,
pub expression: String,
}
#[derive(Args, Debug)]
pub struct DryRun {
pub path: PathBuf,
pub edit_file: PathBuf,
}
#[derive(Subcommand, Debug)]
pub enum Ast {
/// Dump the S expression for a given source file
DumpExpression(DumpExpression),
/// Show what parts of a source file gets captured by an S expression
ShowCaptures(ShowCaptures),
/// Test your edit snippets on a sample file
DryRun(DryRun),
}
#[derive(Subcommand, Debug)]
pub enum Command {
/// quick actions to dump, modify and verify abstract syntax trees
#[command(subcommand)]
Ast(Ast),
/// spawn a language server for use with a text editor
Lsp(Lsp),
}
impl Lsp {
pub(crate) fn resolve_model_and_revision(&self) -> (String, String) {
let default_model = "sentence-transformers/all-MiniLM-L6-v2".to_string();
let default_revision = "refs/pr/21".to_string();
match (self.model_id.clone(), self.revision.clone()) {
(Some(model_id), Some(revision)) => (model_id, revision),
(Some(model_id), None) => (model_id, "main".to_owned()),
(None, Some(revision)) => (default_model, revision),
(None, None) => (default_model, default_revision),
}
}
}

View File

@@ -1,4 +1,3 @@
use super::Args;
use anyhow::{Error as E, Result};
use candle_core::Device;
use candle_core::Tensor;
@@ -8,32 +7,55 @@ use hf_hub::Repo;
use hf_hub::RepoType;
use hf_hub::api::sync::Api;
use std::path::PathBuf;
use tokenizers::DecoderWrapper;
use tokenizers::ModelWrapper;
use tokenizers::NormalizerWrapper;
use tokenizers::PostProcessorWrapper;
use tokenizers::PreTokenizerWrapper;
use tokenizers::Tokenizer;
use tokenizers::TokenizerImpl;
pub struct Embed {
model: BertModel,
tokenizer: Tokenizer,
pub hidden_size: usize,
tokenizer: TokenizerImpl<
ModelWrapper,
NormalizerWrapper,
PreTokenizerWrapper,
PostProcessorWrapper,
DecoderWrapper,
>,
}
impl Embed {
pub(crate) fn new(args: Args) -> Result<Self> {
let device = if let Some(gpu_dev) = args.gpu {
pub(crate) fn new(gpu: Option<usize>, model_id: &str, revision: &str) -> Result<Self> {
let device = if let Some(gpu_dev) = gpu {
Device::new_cuda(gpu_dev)?
} else {
Device::Cpu
};
let (model_id, revision) = args.resolve_model_and_revision();
let (config_path, tokenizer_path, weights_path) =
Self::download_model_files(&model_id, &revision)?;
Self::download_model_files(model_id, revision)?;
let config = std::fs::read_to_string(config_path)?;
let config: Config = serde_json::from_str(&config)?;
let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(E::msg)?;
let mut tokenizer = Tokenizer::from_file(tokenizer_path).map_err(E::msg)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[weights_path], DTYPE, &device)? };
let model = BertModel::load(vb, &config)?;
Ok(Embed { model, tokenizer })
let tokenizer = tokenizer
.with_padding(None)
.with_truncation(None)
.map_err(E::msg)?
.clone();
Ok(Embed {
model,
tokenizer,
hidden_size: config.hidden_size,
})
}
fn download_model_files(model_id: &str, revision: &str) -> Result<(PathBuf, PathBuf, PathBuf)> {
@@ -47,14 +69,9 @@ impl Embed {
Ok((config, tokenizer, weights))
}
pub(crate) fn embed(&mut self, prompt: &str) -> Result<Vec<f32>> {
let tokenizer = self
pub(crate) fn embed(&self, prompt: &str) -> Result<Vec<f32>> {
let tokens = self
.tokenizer
.with_padding(None)
.with_truncation(None)
.map_err(E::msg)?;
let tokens = tokenizer
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
@@ -64,9 +81,9 @@ impl Embed {
let token_type_ids = token_ids.zeros_like()?;
let embeddings = self.model.forward(&token_ids, &token_type_ids, None)?;
let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
let embeddings = normalize_l2(&embeddings)?.reshape(384)?.to_vec1::<f32>()?;
let embeddings = normalize_l2(&embeddings.sum(1)?)?
.reshape(self.hidden_size)?
.to_vec1::<f32>()?;
Ok(embeddings)
}
}

186
src/lsp.rs Normal file
View File

@@ -0,0 +1,186 @@
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;
use tower_lsp::lsp_types::*;
use tower_lsp::{Client, LanguageServer};
pub struct Backend {
pub client: Client,
pub body: Arc<Mutex<HashMap<Url, String>>>,
pub appstate: crate::State,
}
fn string_range_index(s: &str, r: Range) -> &str {
let mut newline_count = 0;
let mut start = None;
let mut end = None;
for (i, c) in s.chars().enumerate() {
if newline_count == r.start.line && start.is_none() {
start.replace(i + r.start.character as usize);
}
if newline_count == r.end.line && end.is_none() {
end.replace(i + r.end.character as usize);
}
if c == '\n' {
newline_count += 1;
}
}
&s[start.unwrap_or_default()..end.unwrap_or(s.len())]
}
#[tower_lsp::async_trait]
impl LanguageServer for Backend {
async fn initialize(
&self,
_: InitializeParams,
) -> tower_lsp::jsonrpc::Result<InitializeResult> {
Ok(InitializeResult {
capabilities: ServerCapabilities {
text_document_sync: Some(TextDocumentSyncCapability::Kind(
TextDocumentSyncKind::FULL,
)),
code_action_provider: Some(CodeActionProviderCapability::Options(
CodeActionOptions::default(),
)),
..Default::default()
},
..Default::default()
})
}
async fn initialized(&self, _: InitializedParams) {
self.client
.log_message(MessageType::INFO, "server initialized!")
.await;
}
async fn shutdown(&self) -> tower_lsp::jsonrpc::Result<()> {
Ok(())
}
async fn did_open(&self, params: DidOpenTextDocumentParams) {
self.body
.lock()
.await
.insert(params.text_document.uri, params.text_document.text);
}
async fn did_change(&self, params: DidChangeTextDocumentParams) {
if let Some(body) = params.content_changes.into_iter().next() {
self.body
.lock()
.await
.insert(params.text_document.uri, body.text);
}
}
async fn code_action(
&self,
params: CodeActionParams,
) -> tower_lsp::jsonrpc::Result<Option<CodeActionResponse>> {
let uri = params.text_document.uri;
let Some(lang) = url_extension(&uri) else {
self.client
.log_message(
MessageType::ERROR,
"unable to determine filetype, file has no extension",
)
.await;
return Ok(None);
};
let body_locked = self.body.lock().await;
let Some(body) = body_locked.get(&uri) else {
return Ok(None);
};
let mut range = params.range;
let selected_text = string_range_index(body, range);
let Some(comment) = ParsedAction::new(selected_text) else {
return Ok(None);
};
let action_response = match comment.action {
Action::Generate => {
range.start = range.end;
self.appstate
.generate(&lang, comment.description, 1)
.map(|v| v.into_iter().map(|s| format!("{s}\n")).collect())
.map_err(|e| e.to_string())
}
Action::Refactor => self
.appstate
.refactor(&lang, comment.description, selected_text, 1)
.map_err(|e| e.to_string()),
};
let closest_matches = match action_response {
Ok(v) => v,
Err(e) => {
self.client
.log_message(MessageType::ERROR, e.to_string())
.await;
return Ok(None);
}
};
let Some(new_text) = closest_matches.into_iter().next() else {
return Ok(None);
};
let text_edit = TextEdit { range, new_text };
let changes: HashMap<Url, _> = [(uri, vec![text_edit])].into_iter().collect();
let edit = Some(WorkspaceEdit {
changes: Some(changes),
..Default::default()
});
let actions = vec![CodeActionOrCommand::CodeAction(CodeAction {
title: "ask silos".to_string(),
edit,
..Default::default()
})];
Ok(Some(actions))
}
}
pub struct ParsedAction<'a> {
action: Action,
description: &'a str,
}
pub enum Action {
Generate,
Refactor,
}
impl<'a> ParsedAction<'a> {
fn new(comment: &'a str) -> Option<ParsedAction<'a>> {
let upto_newline = match comment.rsplit_once("\n") {
Some((upto_newline, _discard)) => upto_newline,
None => comment,
};
let maybe_generate =
upto_newline
.split_once("generate: ")
.map(|(_discard, description)| ParsedAction {
action: Action::Generate,
description,
});
let maybe_refactor =
upto_newline
.split_once("refactor: ")
.map(|(_discard, description)| ParsedAction {
action: Action::Refactor,
description,
});
maybe_generate.or(maybe_refactor)
}
}
fn url_extension(u: &Url) -> Option<String> {
let file_path = u.to_file_path().ok()?;
let extension = file_path.extension()?;
let extension = extension.to_str()?;
Some(extension.to_string())
}

View File

@@ -1,86 +1,86 @@
use actix_web::{App, HttpServer, web};
use anyhow::{Context, Error as E, Result};
use clap::Parser;
use hora::core::ann_index::ANNIndex;
use hora::core::{ann_index::ANNIndex, metrics::Metric::Euclidean};
use hora::index::hnsw_idx::HNSWIndex;
use kdl::KdlDocument;
use state::{State, dump_expression};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;
use tower_lsp::{LspService, Server};
mod args;
mod embed;
mod v1;
// mod v2;
mod lsp;
mod mutation;
mod sources;
mod state;
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on the Nth GPU device.
#[arg(long)]
gpu: Option<usize>,
/// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
#[arg(long)]
model_id: Option<String>,
/// Revision or branch.
#[arg(long)]
revision: Option<String>,
/// The port for the API to listen on
#[arg(long, default_value = "8000")]
port: u16,
}
impl Args {
fn resolve_model_and_revision(&self) -> (String, String) {
let default_model = "sentence-transformers/all-MiniLM-L6-v2".to_string();
let default_revision = "refs/pr/21".to_string();
match (self.model_id.clone(), self.revision.clone()) {
(Some(model_id), Some(revision)) => (model_id, revision),
(Some(model_id), None) => (model_id, "main".to_owned()),
(None, Some(revision)) => (default_model, revision),
(None, None) => (default_model, default_revision),
}
}
}
#[actix_web::main]
#[tokio::main]
async fn main() -> Result<()> {
let args = Args::parse();
let port = args.port;
let mut embed = embed::Embed::new(args)?;
tracing_subscriber::fmt::init();
let args = match args::Cli::parse().command {
args::Command::Ast(ast) => {
match ast {
args::Ast::DumpExpression(source_file) => {
println!("{}", dump_expression(&source_file.path)?);
}
args::Ast::ShowCaptures(show_captures) => {
let source_bytes = std::fs::read(&show_captures.path)?;
let langfn = state::lang_from_file_extension(&show_captures.path)?;
let tree = state::parse_into_tree(&source_bytes, &langfn)?;
let root_node = tree.root_node();
let cooked = mutation::query(
root_node,
&show_captures.expression,
&langfn,
&source_bytes,
);
println!("{:#?}", cooked);
}
args::Ast::DryRun(dry_run) => {
let mutation_collection = mutation::from_path(dry_run.edit_file)?;
let source_bytes = std::fs::read(&dry_run.path)?;
let langfn = state::lang_from_file_extension(&dry_run.path)?;
let tree = state::parse_into_tree(&source_bytes, &langfn)?;
let root_node = tree.root_node();
let cooked =
mutation::apply(langfn, &source_bytes, root_node, &mutation_collection)?;
println!("{cooked}");
}
}
return Ok(());
}
args::Command::Lsp(lsp) => lsp,
};
let (model_id, revision) = args.resolve_model_and_revision();
let embed = embed::Embed::new(args.gpu, &model_id, &revision)?;
let mut dict = HashMap::default();
let dimensions = embed.hidden_size;
let paths = glob::glob("./snippets/v1/*/*.kdl")?;
for path in paths {
let path = path?;
let parent = path
.components()
.rev()
.nth(1)
.unwrap()
.as_os_str()
.to_string_lossy()
.to_string();
for (language, paths) in sources::rule_files(args.snippets.join("generate"))? {
for path in paths {
let current_lang_index = dict
.entry(language.clone())
.or_insert_with(|| HNSWIndex::new(dimensions, &Default::default()));
let current_lang_index = dict.entry(parent).or_insert_with(|| {
let dimension = 384;
let params = hora::index::hnsw_params::HNSWParams::<f32>::default();
let doc_str = std::fs::read_to_string(&path)?;
let doc: KdlDocument = doc_str
.parse()
.context(format!("failed to parse KDL: {}", path.display()))?;
HNSWIndex::<f32, String>::new(dimension, &params)
});
let doc_str = std::fs::read_to_string(path)?;
let doc: KdlDocument = doc_str.parse().context("failed to parse KDL")?;
let Some(desc) = doc.get_arg("desc").and_then(|v| v.as_string()) else {
continue;
};
let Some(body) = doc.get_arg("body").and_then(|v| v.as_string()) else {
continue;
};
current_lang_index
.add(&embed.embed(desc)?, body.to_string())
.map_err(E::msg)?;
let Some(desc) = doc.get_arg("desc").and_then(|v| v.as_string()) else {
continue;
};
let Some(body) = doc.get_arg("body").and_then(|v| v.as_string()) else {
continue;
};
current_lang_index
.add(&embed.embed(desc)?, body.to_string())
.map_err(E::msg)?;
}
}
for index in dict.values_mut() {
@@ -88,18 +88,47 @@ async fn main() -> Result<()> {
.build(hora::core::metrics::Metric::Euclidean)
.map_err(E::msg)?;
}
let appstate = v1::api::AppState { dict, embed };
let appstate_wrapped = web::Data::new(appstate.build());
let mut refactor_dict = HashMap::new();
let mut mutations_collection = vec![];
for (language, paths) in sources::rule_files(args.snippets.join("refactor"))? {
for path in paths {
let mutations = mutation::from_path(path)?;
let current_lang_index = refactor_dict
.entry(language.clone())
.or_insert_with(|| HNSWIndex::new(dimensions, &Default::default()));
HttpServer::new(move || {
App::new()
.app_data(appstate_wrapped.clone())
.service(v1::api::get_snippet)
.service(v1::api::add_snippet)
})
.bind(("127.0.0.1", port))?
.run()
.await
.map_err(E::from)
current_lang_index
.add(
&embed.embed(&mutations.description)?,
mutations_collection.len(),
)
.map_err(E::msg)?;
mutations_collection.push(mutations);
}
}
for index in refactor_dict.values_mut() {
index.build(Euclidean).map_err(E::msg)?;
}
let appstate = State::new(
embed,
state::Generate { dict },
state::Refactor {
dict: refactor_dict,
mutations_collection,
},
);
let stdin = tokio::io::stdin();
let stdout = tokio::io::stdout();
let (service, socket) = LspService::new(|client| lsp::Backend {
client,
body: Arc::new(Mutex::new(HashMap::default())),
appstate,
});
Server::new(stdin, stdout, socket).serve(service).await;
Ok(())
}

210
src/mutation.rs Normal file
View File

@@ -0,0 +1,210 @@
use std::collections::HashMap;
use std::path::Path;
use tracing::debug;
use tree_sitter::{Language, Node, Query, QueryCursor, StreamingIterator};
use anyhow::{Result, bail};
use kdl::KdlDocument;
#[derive(Debug)]
pub struct Mutation {
pub expression: String,
pub substitute: Vec<Substitute>,
}
pub struct MutationCollection {
pub description: String,
pub mutations: Vec<Mutation>,
}
#[derive(Debug)]
pub enum Substitute {
Literal(String),
Capture(String),
}
pub fn from_path<P: AsRef<Path>>(path: P) -> Result<MutationCollection> {
let contents = std::fs::read_to_string(path)?;
let doc: KdlDocument = contents.parse()?;
let mut mutations = vec![];
let mut description = None;
for node in doc.nodes() {
let node_name = node.name().value();
if node_name != "mutation" && node_name != "description" {
bail!(
"document root must only contain `mutation` or `description` nodes: got {node_name}"
);
}
if node_name == "description" {
description.replace(
node.entry(0)
.unwrap()
.value()
.as_string()
.unwrap()
.to_string(),
);
continue;
}
let node = node.children().unwrap();
let Some(expression) = node.get_arg("expression").and_then(|v| v.as_string()) else {
bail!("mutation node must contain an expression");
};
let Some(substitute) = node.get("substitute") else {
bail!("mutation node must contain an substitute");
};
let children = substitute.children().unwrap().nodes();
let mut substitute = vec![];
for child in children {
let attrib = child.entry(0).unwrap().value().as_string().unwrap();
let substitutor = match child.name().value() {
"literal" => Substitute::Literal(attrib.to_string()),
"capture" => Substitute::Capture(attrib.to_string()),
_ => unreachable!(),
};
substitute.push(substitutor);
}
mutations.push(Mutation {
expression: expression.to_string(),
substitute,
})
}
let Some(description) = description else {
bail!("mutation collection contains no `description`");
};
Ok(MutationCollection {
description,
mutations,
})
}
pub fn apply(
lang: Language,
source_bytes: &[u8],
root_node: Node<'_>,
mutations: &MutationCollection,
) -> Result<String, anyhow::Error> {
let mut split_ats = vec![];
let mut query_result_map = HashMap::new();
for mutation in &mutations.mutations {
for query_result in query(root_node, mutation.expression.as_str(), &lang, source_bytes) {
debug!("mutation query expression matched: {query_result:?}");
split_ats.push(query_result.start);
split_ats.push(query_result.end);
let mut ast_rewrite = String::default();
for sub in &mutation.substitute {
ast_rewrite.push_str(match sub {
Substitute::Literal(attrib) => attrib,
Substitute::Capture(attrib) => &query_result.captures[attrib],
})
}
debug!("AST rewritten to {ast_rewrite:?}");
query_result_map.insert(query_result.start, ast_rewrite);
}
}
split_ats.sort();
let splits = split_at_indices(source_bytes, &split_ats);
let mut output = String::default();
for (i, split) in splits.indices.iter().zip(splits.values) {
let split = std::str::from_utf8(split)?;
output.push_str(query_result_map.get(i).map(|v| v.as_str()).unwrap_or(split));
}
Ok(output)
}
#[derive(Debug)]
pub struct QueryCooked {
captures: HashMap<String, String>,
end: usize,
start: usize,
}
pub struct SplitMap<'a> {
values: Vec<&'a [u8]>,
indices: Vec<usize>,
}
fn split_at_indices<'a>(c: &'a [u8], idx: &[usize]) -> SplitMap<'a> {
let mut a = 0;
let mut values = vec![];
let mut indices = vec![a];
for &b in idx {
values.push(&c[a..b]);
a = b;
indices.push(a);
}
values.push(&c[a..]);
assert_eq!(values.len(), indices.len());
SplitMap { values, indices }
}
pub fn query<'a>(
node: Node<'a>,
expr: &'a str,
lang: &Language,
source_bytes: &[u8],
) -> Vec<QueryCooked> {
let query = Query::new(lang, expr).unwrap();
let mut qc = QueryCursor::new();
let mut query_matches = qc.matches(&query, node, source_bytes);
let capture_names = query.capture_names();
// println!("names: {capture_names:#?}");
let mut cooked = vec![];
while let Some(matcha) = query_matches.next() {
let mut capture_cooked = HashMap::new();
let mut start = 0;
let mut end = 0;
if matcha.captures.is_empty() {
continue;
}
// println!("match {:#?}", matcha.id());
for (ix, name) in capture_names.iter().enumerate() {
let nodes = matcha.nodes_for_capture_index(ix.try_into().unwrap());
let mut start_pos = None;
let mut end_pos = None;
debug!("matches for {name}");
for node in nodes {
start_pos.get_or_insert(node.start_byte());
end_pos.replace(node.end_byte());
debug!("hit {node:#?}");
}
let (Some(start_pos), Some(end_pos)) = (start_pos, end_pos) else {
continue;
};
if *name == "root" {
start = start_pos;
end = end_pos;
}
let text_bytes = &source_bytes[start_pos..end_pos];
let text = std::str::from_utf8(text_bytes).unwrap();
// println!("text: {text}");
capture_cooked.insert(name.to_string(), text.to_string());
}
cooked.push(QueryCooked {
start,
end,
captures: capture_cooked,
})
}
cooked
}

34
src/sources.rs Normal file
View File

@@ -0,0 +1,34 @@
use std::{
collections::HashMap,
fs, io,
path::{Path, PathBuf},
};
pub fn rule_files<P: AsRef<Path>>(path: P) -> io::Result<HashMap<String, Vec<PathBuf>>> {
let per_language_dirs: Vec<_> = fs::read_dir(path)?
.filter_map(|res| res.ok())
.map(|direntry| direntry.path())
.filter(|dir| dir.is_dir())
.collect();
let mut basename_to_paths = HashMap::new();
for language_dir in per_language_dirs {
let Some(dirname) = language_dir
.file_stem()
.and_then(|v| v.to_str())
.map(|v| v.to_string())
else {
continue;
};
let rule_file_paths: Vec<_> = fs::read_dir(&language_dir)?
.filter_map(|res| res.ok())
.map(|entry| entry.path())
.filter(|file| file.is_file() && file.extension().is_some_and(|ext| ext == "kdl"))
.map(|path| path.to_path_buf())
.collect();
basename_to_paths.insert(dirname, rule_file_paths);
}
Ok(basename_to_paths)
}
// fn prebuilt_index();

155
src/state.rs Normal file
View File

@@ -0,0 +1,155 @@
use crate::mutation;
use derive_more::Display;
use derive_more::Error;
use hora::core::ann_index::ANNIndex;
use hora::index::hnsw_idx::HNSWIndex;
use std::collections::HashMap;
use std::path::Path;
use tree_sitter::Parser;
#[derive(Debug, Display, Error)]
pub enum Error {
#[display("failed to embed prompt")]
EmbedFailed,
#[display("snippets were requested for an unknown language")]
UnknownLang,
#[display("failed to parse corpus of code to apply mutation on")]
SnippetParsing,
}
pub struct Refactor {
pub dict: HashMap<String, HNSWIndex<f32, usize>>,
pub mutations_collection: Vec<mutation::MutationCollection>,
}
impl Refactor {
pub fn search(
&self,
lang: &str,
target: &[f32],
body: &str,
top_k: usize,
) -> Result<Vec<String>, Error> {
let langfn = lang_from_name(lang)?;
let source_bytes = body.as_bytes();
let tree = parse_into_tree(source_bytes, &langfn)?;
let root_node = tree.root_node();
// search for k nearest neighbors
let collected = self.dict[lang]
.search(target, top_k)
.iter()
.filter_map(|&index| {
let applied = mutation::apply(
langfn.clone(),
source_bytes,
root_node,
&self.mutations_collection[index],
);
match applied {
Ok(v) => Some(v),
Err(e) => {
tracing::error!(
collection_index = index,
"failed to apply mutations from collection {}",
e
);
None
}
}
})
.collect();
Ok(collected)
}
}
pub fn lang_from_name(s: &str) -> Result<tree_sitter::Language, Error> {
Ok(match s {
"go" => tree_sitter_go::LANGUAGE,
"c" | "h" => tree_sitter_c::LANGUAGE,
"cpp" | "hpp" => tree_sitter_cpp::LANGUAGE,
"js" | "ts" => tree_sitter_javascript::LANGUAGE,
"rs" => tree_sitter_rust::LANGUAGE,
_ => return Err(Error::UnknownLang),
}
.into())
}
pub fn lang_from_file_extension(path: &Path) -> Result<tree_sitter::Language, Error> {
let Some(lang) = path.extension() else {
return Err(Error::UnknownLang);
};
let lang = lang.to_str().ok_or(Error::UnknownLang)?;
lang_from_name(lang)
}
// parses `body` written in the language `langfn` into tree sitter AST
pub fn parse_into_tree(
body: &[u8],
langfn: &tree_sitter::Language,
) -> Result<tree_sitter::Tree, Error> {
let mut parser = Parser::new();
parser
.set_language(langfn)
.map_err(|_| Error::UnknownLang)?;
let tree = parser.parse(body, None).ok_or(Error::SnippetParsing)?;
Ok(tree)
}
pub fn dump_expression(path: &Path) -> Result<String, Error> {
let source_bytes = std::fs::read(path).map_err(|_| Error::SnippetParsing)?;
let tree = parse_into_tree(&source_bytes, &lang_from_file_extension(path)?)?;
Ok(tree.root_node().to_sexp().to_string())
}
pub struct Generate {
pub dict: HashMap<String, HNSWIndex<f32, String>>,
}
impl Generate {
fn search(&self, lang: &str, target: &[f32], top_k: usize) -> Result<Vec<String>, Error> {
let Some(snippets_for_lang) = self.dict.get(lang) else {
return Err(Error::UnknownLang);
};
Ok(snippets_for_lang.search(target, top_k))
}
}
pub struct State {
embed: crate::embed::Embed,
generate: Generate,
refactor: Refactor,
}
impl State {
pub fn new(embed: crate::embed::Embed, generate: Generate, refactor: Refactor) -> Self {
Self {
embed,
generate,
refactor,
}
}
pub fn generate(&self, lang: &str, prompt: &str, top_k: usize) -> Result<Vec<String>, Error> {
let Ok(target) = self.embed.embed(prompt) else {
return Err(Error::EmbedFailed);
};
self.generate.search(lang, &target, top_k)
}
pub fn refactor(
&self,
lang: &str,
prompt: &str,
body: &str,
top_k: usize,
) -> Result<Vec<String>, Error> {
let Ok(target) = self.embed.embed(prompt) else {
return Err(Error::EmbedFailed);
};
self.refactor.search(lang, &target, body, top_k)
}
}

View File

@@ -1,106 +0,0 @@
use hora::core::ann_index::ANNIndex;
use std::{collections::HashMap, sync::Mutex};
use hora::index::hnsw_idx::HNSWIndex;
use serde::{Deserialize, Serialize};
use crate::embed;
use super::errors::GetError;
use actix_web::{Responder, post, web};
use anyhow::Result;
#[derive(Deserialize)]
pub struct SnippetRequest {
desc: String,
top_k: Option<usize>,
}
#[derive(Deserialize, Debug)]
pub struct SnippetOnDisk {
pub body: String,
pub desc: String,
}
pub struct AppStateWrapper {
inner: Mutex<AppState>,
}
pub struct AppState {
pub dict: HashMap<String, HNSWIndex<f32, String>>,
pub embed: embed::Embed,
}
impl AppState {
pub fn build(self) -> AppStateWrapper {
AppStateWrapper {
inner: Mutex::new(self),
}
}
}
#[derive(Serialize)]
pub struct SnippetResponse {
id: usize,
snippet: Snippet,
}
#[derive(Serialize, Deserialize)]
pub struct Snippet {
lang: String,
desc: String,
body: String,
}
#[post("/api/v1/get")]
pub(crate) async fn get_snippet(
data: web::Data<AppStateWrapper>,
snippet_request: web::Json<SnippetRequest>,
) -> Result<impl Responder, GetError> {
let Some((prompt, lang)) = snippet_request.desc.rsplit_once(" in ") else {
return Err(GetError::MissingSuffix);
};
let Ok(mut appstate) = data.inner.lock() else {
return Err(GetError::Busy);
};
let Ok(target) = appstate.embed.embed(prompt) else {
return Err(GetError::EmbedFailed);
};
// search for k nearest neighbors
let closest: Vec<String> =
appstate.dict[lang].search(&target, snippet_request.top_k.unwrap_or(1));
Ok(web::Json(closest))
}
#[post("/api/v1/add")]
pub(crate) async fn add_snippet(
data: web::Data<AppStateWrapper>,
snippet: web::Json<Snippet>,
) -> Result<impl Responder, GetError> {
let Ok(mut appstate) = data.inner.lock() else {
return Err(GetError::Busy);
};
let Ok(embedding) = appstate.embed.embed(&snippet.desc) else {
return Err(GetError::EmbedFailed);
};
let index = appstate
.dict
.entry(snippet.lang.clone())
.or_insert_with(|| {
let dimension = 384;
let params = hora::index::hnsw_params::HNSWParams::<f32>::default();
HNSWIndex::<f32, String>::new(dimension, &params)
});
index.add(&embedding, snippet.body.clone()).unwrap();
index.build(hora::core::metrics::Metric::Euclidean).unwrap();
Ok(format!(
"{} {} {}",
snippet.body, snippet.lang, snippet.desc
))
}

View File

@@ -1,36 +0,0 @@
use actix_web::{
HttpResponse, error,
http::{StatusCode, header::ContentType},
};
use derive_more::derive::{Display, Error};
use serde_json::json;
#[derive(Debug, Display, Error)]
pub enum GetError {
#[display("the server is busy. come back later.")]
Busy,
#[display("end your request with ` in somelang`.")]
MissingSuffix,
#[display("failed to embed your prompt.")]
EmbedFailed,
}
impl error::ResponseError for GetError {
fn error_response(&self) -> HttpResponse {
let message = json!({
"message": self.to_string(),
})
.to_string();
HttpResponse::build(self.status_code())
.insert_header(ContentType::json())
.body(message)
}
fn status_code(&self) -> StatusCode {
match *self {
Self::EmbedFailed => StatusCode::INTERNAL_SERVER_ERROR,
Self::MissingSuffix => StatusCode::BAD_REQUEST,
Self::Busy => StatusCode::GATEWAY_TIMEOUT,
}
}
}

View File

@@ -1,2 +0,0 @@
pub(crate) mod api;
pub(crate) mod errors;